Merge remote-tracking branch 'eclipse/master' into ui_multisession_arbiter
commit
20f451e807
|
@ -50,11 +50,11 @@ import java.util.regex.PatternSyntaxException;
|
||||||
* <h4 id="Resources">Resources</h4>
|
* <h4 id="Resources">Resources</h4>
|
||||||
*
|
*
|
||||||
* <p>Configurations are specified by resources. A resource contains a set of
|
* <p>Configurations are specified by resources. A resource contains a set of
|
||||||
* name/value pairs as XML data. Each resource is named by either a
|
* name/value pairs as XML data. Each resource is named by either a
|
||||||
* <code>String</code>. If named by a <code>String</code>,
|
* <code>String</code> or a <code>Path</code>. If named by a
|
||||||
* then the classpath is examined for a file with that name. If named by a
|
* <code>String</code>, then the classpath is examined for a file with that
|
||||||
* <code>Path</code>, then the local filesystem is examined directly, without
|
* name. If named by a <code>Path</code>, then the local filesystem is
|
||||||
* referring to the classpath.
|
* examined directly, without referring to the classpath.
|
||||||
*
|
*
|
||||||
* <p>Unless explicitly turned off, Hadoop by default specifies two
|
* <p>Unless explicitly turned off, Hadoop by default specifies two
|
||||||
* resources, loaded in-order from the classpath: <ol>
|
* resources, loaded in-order from the classpath: <ol>
|
||||||
|
|
|
@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
|
||||||
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
|
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
|
||||||
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
|
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
|
||||||
this.arrays = arrays;
|
this.arrays = arrays;
|
||||||
|
this.size = arrays.get(0).size(0);
|
||||||
|
|
||||||
//Check that dimension 0 matches:
|
//Check that dimension 0 matches:
|
||||||
if(arrays.size() > 1){
|
if(arrays.size() > 1){
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.datavec.image.data.Image;
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.transform.ImageTransform;
|
import org.datavec.image.transform.ImageTransform;
|
||||||
|
@ -35,10 +36,9 @@ import java.util.Random;
|
||||||
/**
|
/**
|
||||||
* Created by nyghtowl on 12/17/15.
|
* Created by nyghtowl on 12/17/15.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public abstract class BaseImageLoader implements Serializable {
|
public abstract class BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
|
|
||||||
|
|
||||||
public enum MultiPageMode {
|
public enum MultiPageMode {
|
||||||
MINIBATCH, FIRST //, CHANNELS,
|
MINIBATCH, FIRST //, CHANNELS,
|
||||||
}
|
}
|
||||||
|
@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
|
||||||
public abstract INDArray asMatrix(File f) throws IOException;
|
public abstract INDArray asMatrix(File f) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load an image from a file to an INDArray
|
||||||
|
* @param f File to load the image from
|
||||||
|
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
|
||||||
|
* in NHWC/channels_last [1, height, width, channels] format
|
||||||
|
* @return Image file as an INDArray
|
||||||
|
*/
|
||||||
|
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
|
||||||
|
|
||||||
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
||||||
|
/**
|
||||||
|
* Load an image file from an input stream to an INDArray
|
||||||
|
* @param inputStream Input stream to load the image from
|
||||||
|
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
|
||||||
|
* in NHWC/channels_last [1, height, width, channels] format
|
||||||
|
* @return Image file stream as an INDArray
|
||||||
|
*/
|
||||||
|
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
|
||||||
public abstract Image asImageMatrix(File f) throws IOException;
|
public abstract Image asImageMatrix(File f) throws IOException;
|
||||||
|
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
|
||||||
|
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
|
||||||
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
||||||
|
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
|
||||||
|
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||||
|
|
||||||
|
|
||||||
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
||||||
|
|
|
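For reference, a minimal usage sketch of the new boolean nchw overloads introduced in the hunk above; it assumes the NativeImageLoader implementation shown later in this diff, and the image path is hypothetical:

import java.io.File;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NchwNhwcLoaderSketch {
    public static void main(String[] args) throws Exception {
        File f = new File("/path/to/image.jpg");                     // hypothetical path
        NativeImageLoader loader = new NativeImageLoader(32, 32, 3); // height, width, channels

        INDArray nchw = loader.asMatrix(f);         // default is channels first: [1, 3, 32, 32]
        INDArray nchw2 = loader.asMatrix(f, true);  // explicit NCHW, same result as above
        INDArray nhwc = loader.asMatrix(f, false);  // channels last: [1, 32, 32, 3]

        System.out.println(java.util.Arrays.toString(nhwc.shape()));
    }
}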
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
||||||
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class CifarLoader extends NativeImageLoader implements Serializable {
|
public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||||
public static final int NUM_TRAIN_IMAGES = 50000;
|
public static final int NUM_TRAIN_IMAGES = 50000;
|
||||||
public static final int NUM_TEST_IMAGES = 10000;
|
public static final int NUM_TEST_IMAGES = 10000;
|
||||||
|
|
|
@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
public INDArray asMatrix(File f) throws IOException {
|
public INDArray asMatrix(File f) throws IOException {
|
||||||
return NDArrayUtil.toNDArray(fromFile(f));
|
return asMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
return asMatrix(is, nchw);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader {
|
||||||
* @return the input stream to convert
|
* @return the input stream to convert
|
||||||
*/
|
*/
|
||||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||||
if (channels == 3)
|
return asMatrix(inputStream, true);
|
||||||
return toBgr(inputStream);
|
}
|
||||||
try {
|
|
||||||
BufferedImage image = ImageIO.read(inputStream);
|
@Override
|
||||||
return asMatrix(image);
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
} catch (IOException e) {
|
INDArray ret;
|
||||||
throw new IOException("Unable to load image", e);
|
if (channels == 3) {
|
||||||
|
ret = toBgr(inputStream);
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
BufferedImage image = ImageIO.read(inputStream);
|
||||||
|
ret = asMatrix(image);
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new IOException("Unable to load image", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if(ret.rank() == 3){
|
||||||
|
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
|
||||||
|
}
|
||||||
|
if(!nchw)
|
||||||
|
ret = ret.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
||||||
|
return asImageMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asImageMatrix(bis);
|
return asImageMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||||
if (channels == 3)
|
return asImageMatrix(inputStream, true);
|
||||||
return toBgrImage(inputStream);
|
}
|
||||||
try {
|
|
||||||
BufferedImage image = ImageIO.read(inputStream);
|
@Override
|
||||||
INDArray asMatrix = asMatrix(image);
|
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
org.datavec.image.data.Image ret;
|
||||||
} catch (IOException e) {
|
if (channels == 3) {
|
||||||
throw new IOException("Unable to load image", e);
|
ret = toBgrImage(inputStream);
|
||||||
|
} else {
|
||||||
|
try {
|
||||||
|
BufferedImage image = ImageIO.read(inputStream);
|
||||||
|
INDArray asMatrix = asMatrix(image);
|
||||||
|
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||||
|
} catch (IOException e) {
|
||||||
|
throw new IOException("Unable to load image", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if(ret.getImage().rank() == 3){
|
||||||
|
INDArray a = ret.getImage();
|
||||||
|
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
|
||||||
|
}
|
||||||
|
if(!nchw)
|
||||||
|
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
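The permute(0,2,3,1) and permute(0,3,1,2) calls used throughout this change convert between the two layouts. A small self-contained sketch of that relationship, using only ND4J calls that appear elsewhere in this diff:

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class LayoutPermuteSketch {
    public static void main(String[] args) {
        INDArray nchw = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); // [n, c, h, w]
        INDArray nhwc = nchw.permute(0, 2, 3, 1);                  // permuted view, shape [1, 32, 32, 3]
        INDArray back = nhwc.permute(0, 3, 1, 2);                  // back to [1, 3, 32, 32]
        System.out.println(java.util.Arrays.toString(nhwc.shape()));
    }
}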
@ -17,6 +17,7 @@
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.datavec.api.io.filters.BalancedPathFilter;
|
import org.datavec.api.io.filters.BalancedPathFilter;
|
||||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||||
|
@ -48,6 +49,7 @@ import java.util.Random;
|
||||||
* most images are in color, although a few are grayscale
|
* most images are in color, although a few are grayscale
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class LFWLoader extends BaseImageLoader implements Serializable {
|
public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
|
|
||||||
public final static int NUM_IMAGES = 13233;
|
public final static int NUM_IMAGES = 13233;
|
||||||
|
@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(File f) throws IOException {
|
public Image asImageMatrix(File f) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(File f) throws IOException {
|
public INDArray asMatrix(File f) throws IOException {
|
||||||
|
return asMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asMatrix(bis);
|
return asMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(InputStream is) throws IOException {
|
public INDArray asMatrix(InputStream is) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
return asMatrix(is, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
Mat mat = streamToMat(inputStream);
|
||||||
INDArray a;
|
INDArray a;
|
||||||
if (this.multiPageMode != null) {
|
if (this.multiPageMode != null) {
|
||||||
a = asMatrix(mat.data(), mat.cols());
|
a = asMatrix(mat.data(), mat.cols());
|
||||||
}else{
|
}else{
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
|
@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
a = asMatrix(image);
|
a = asMatrix(image);
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
}
|
}
|
||||||
return a;
|
if(nchw) {
|
||||||
|
return a;
|
||||||
|
} else {
|
||||||
|
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
public Image asImageMatrix(String filename) throws IOException {
|
public Image asImageMatrix(String filename) throws IOException {
|
||||||
return asImageMatrix(filename);
|
return asImageMatrix(new File(filename));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(File f) throws IOException {
|
public Image asImageMatrix(File f) throws IOException {
|
||||||
|
return asImageMatrix(f, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||||
return asImageMatrix(bis);
|
return asImageMatrix(bis, nchw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(InputStream is) throws IOException {
|
public Image asImageMatrix(InputStream is) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
return asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
Mat mat = streamToMat(inputStream);
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||||
|
@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
pixDestroy(pix);
|
pixDestroy(pix);
|
||||||
}
|
}
|
||||||
INDArray a = asMatrix(image);
|
INDArray a = asMatrix(image);
|
||||||
|
if(!nchw)
|
||||||
|
a = a.permute(0,2,3,1); //NCHW to NHWC
|
||||||
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
||||||
|
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
|
|
|
@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
protected int patternPosition = 0;
|
protected int patternPosition = 0;
|
||||||
@Getter @Setter
|
@Getter @Setter
|
||||||
protected boolean logLabelCountOnInit = true;
|
protected boolean logLabelCountOnInit = true;
|
||||||
|
@Getter @Setter
|
||||||
|
protected boolean nchw_channels_first = true;
|
||||||
|
|
||||||
public final static String HEIGHT = NAME_SPACE + ".height";
|
public final static String HEIGHT = NAME_SPACE + ".height";
|
||||||
public final static String WIDTH = NAME_SPACE + ".width";
|
public final static String WIDTH = NAME_SPACE + ".width";
|
||||||
|
@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
|
|
||||||
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||||
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||||
|
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||||
|
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
this.channels = channels;
|
this.channels = channels;
|
||||||
|
@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
this.labelMultiGenerator = labelMultiGenerator;
|
this.labelMultiGenerator = labelMultiGenerator;
|
||||||
this.imageTransform = imageTransform;
|
this.imageTransform = imageTransform;
|
||||||
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
|
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
|
||||||
|
this.nchw_channels_first = nchw_channels_first;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean containsFormat(String format) {
|
protected boolean containsFormat(String format) {
|
||||||
|
@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
return next();
|
return next();
|
||||||
try {
|
try {
|
||||||
invokeListeners(image);
|
invokeListeners(image);
|
||||||
INDArray row = imageLoader.asMatrix(image);
|
INDArray array = imageLoader.asMatrix(image);
|
||||||
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE);
|
if(!nchw_channels_first){
|
||||||
ret = RecordConverter.toRecord(row);
|
array = array.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
|
||||||
|
ret = RecordConverter.toRecord(array);
|
||||||
if (appendLabel || writeLabel){
|
if (appendLabel || writeLabel){
|
||||||
if(labelMultiGenerator != null){
|
if(labelMultiGenerator != null){
|
||||||
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
|
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
|
||||||
|
@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<List<Writable>> next(int num) {
|
public List<List<Writable>> next(int num) {
|
||||||
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
|
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
|
||||||
|
|
||||||
if (imageLoader == null) {
|
if (imageLoader == null) {
|
||||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||||
|
@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(!nchw_channels_first){
|
||||||
|
features = features.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
}
|
||||||
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
|
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
|
||||||
|
|
||||||
|
|
||||||
|
@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
if (imageLoader == null) {
|
if (imageLoader == null) {
|
||||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||||
}
|
}
|
||||||
INDArray row = imageLoader.asMatrix(dataInputStream);
|
INDArray array = imageLoader.asMatrix(dataInputStream);
|
||||||
List<Writable> ret = RecordConverter.toRecord(row);
|
if(!nchw_channels_first)
|
||||||
|
array = array.permute(0,2,3,1);
|
||||||
|
List<Writable> ret = RecordConverter.toRecord(array);
|
||||||
if (appendLabel)
|
if (appendLabel)
|
||||||
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
|
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
|
||||||
return ret;
|
return ret;
|
||||||
|
|
|
@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
|
||||||
public class ImageRecordReader extends BaseImageRecordReader {
|
public class ImageRecordReader extends BaseImageRecordReader {
|
||||||
|
|
||||||
|
|
||||||
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */
|
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/
|
||||||
public ImageRecordReader() {
|
public ImageRecordReader() {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||||
|
*/
|
||||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
|
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
|
||||||
super(height, width, channels, labelGenerator);
|
super(height, width, channels, labelGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||||
|
*/
|
||||||
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
|
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
|
||||||
super(height, width, channels, labelGenerator);
|
super(height, width, channels, labelGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending no labels. */
|
/** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */
|
||||||
public ImageRecordReader(long height, long width, long channels) {
|
public ImageRecordReader(long height, long width, long channels) {
|
||||||
super(height, width, channels, (PathLabelGenerator) null);
|
super(height, width, channels, (PathLabelGenerator) null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
/** Loads images with given height, width, and channels, appending no labels - in specified format<br>
|
||||||
|
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||||
|
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||||
|
*/
|
||||||
|
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
|
||||||
|
super(height, width, channels, nchw_channels_first, null, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width] */
|
||||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||||
ImageTransform imageTransform) {
|
ImageTransform imageTransform) {
|
||||||
super(height, width, channels, labelGenerator, imageTransform);
|
super(height, width, channels, labelGenerator, imageTransform);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending no labels. */
|
/** Loads images with given height, width, and channels, appending labels returned by the generator.<br>
|
||||||
|
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||||
|
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||||
|
*/
|
||||||
|
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||||
|
ImageTransform imageTransform) {
|
||||||
|
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads images with given height, width, and channels, appending no labels.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||||
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
|
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
|
||||||
super(height, width, channels, null, imageTransform);
|
super(height, width, channels, null, imageTransform);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
/** Loads images with given height, width, and channels, appending labels returned by the generator
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||||
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
|
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
|
||||||
super(height, width, 1, labelGenerator);
|
super(height, width, 1, labelGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Loads images with given height, width, and channels = 1, appending no labels. */
|
/** Loads images with given height, width, and channels = 1, appending no labels.
|
||||||
|
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||||
public ImageRecordReader(long height, long width) {
|
public ImageRecordReader(long height, long width) {
|
||||||
super(height, width, 1, null, null);
|
super(height, width, 1, null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
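To make the effect of the new nchw_channels_first constructors concrete, a hedged usage sketch (the image directory is hypothetical; package names follow the imports visible elsewhere in this diff):

import java.io.File;
import java.util.List;
import java.util.Random;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;

public class NhwcRecordReaderSketch {
    public static void main(String[] args) throws Exception {
        File imageDir = new File("/path/to/images");   // hypothetical directory of images

        // false => NHWC (channels last): each record's image array has shape [1, 32, 32, 3]
        ImageRecordReader reader = new ImageRecordReader(32, 32, 3, false);
        reader.initialize(new FileSplit(imageDir, new Random(12345)));

        while (reader.hasNext()) {
            List<Writable> record = reader.next();
            // record.get(0) wraps the [1, 32, 32, 3] image array
            System.out.println(record.get(0));
        }
    }
}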
@ -16,10 +16,16 @@
|
||||||
|
|
||||||
package org.datavec.image.loader;
|
package org.datavec.image.loader;
|
||||||
|
|
||||||
|
import org.datavec.image.data.Image;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
|
import java.io.BufferedInputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.InputStream;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
@ -208,4 +214,57 @@ public class TestImageLoader {
|
||||||
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
||||||
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNCHW_NHWC() throws Exception {
|
||||||
|
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||||
|
|
||||||
|
ImageLoader il = new ImageLoader(32, 32, 3);
|
||||||
|
|
||||||
|
//asMatrix(File, boolean)
|
||||||
|
INDArray a_nchw = il.asMatrix(f);
|
||||||
|
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||||
|
INDArray a_nhwc = il.asMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw = il.asMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw2 = il.asMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nhwc = il.asMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(File, boolean)
|
||||||
|
Image i_nchw = il.asImageMatrix(f);
|
||||||
|
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||||
|
Image i_nhwc = il.asImageMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw = il.asImageMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw2 = il.asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nhwc = il.asImageMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
|
||||||
import org.bytedeco.javacv.Frame;
|
import org.bytedeco.javacv.Frame;
|
||||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
|
import org.datavec.image.data.Image;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.awt.image.BufferedImage;
|
import java.awt.image.BufferedImage;
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
|
@ -604,4 +603,56 @@ public class TestNativeImageLoader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNCHW_NHWC() throws Exception {
|
||||||
|
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||||
|
|
||||||
|
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
||||||
|
|
||||||
|
//asMatrix(File, boolean)
|
||||||
|
INDArray a_nchw = il.asMatrix(f);
|
||||||
|
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||||
|
INDArray a_nhwc = il.asMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw = il.asMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nchw2 = il.asMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
a_nhwc = il.asMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(a_nchw, a_nchw2);
|
||||||
|
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(File, boolean)
|
||||||
|
Image i_nchw = il.asImageMatrix(f);
|
||||||
|
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||||
|
Image i_nhwc = il.asImageMatrix(f, false);
|
||||||
|
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
|
||||||
|
|
||||||
|
//asImageMatrix(InputStream, boolean)
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw = il.asImageMatrix(is);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nchw2 = il.asImageMatrix(is, true);
|
||||||
|
}
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||||
|
i_nhwc = il.asImageMatrix(is, false);
|
||||||
|
}
|
||||||
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.*;
|
||||||
import java.io.IOException;
|
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
@ -467,5 +467,87 @@ public class TestImageRecordReader {
|
||||||
return count;
|
return count;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNCHW_NCHW() throws Exception {
|
||||||
|
//Idea: NCHW and NHWC readers should return the same image data, differing only in layout
|
||||||
|
File f0 = testDir.newFolder();
|
||||||
|
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||||
|
|
||||||
|
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||||
|
FileSplit fs1 = new FileSplit(f0, new Random(12345));
|
||||||
|
assertEquals(6, fs0.locations().length);
|
||||||
|
assertEquals(6, fs1.locations().length);
|
||||||
|
|
||||||
|
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
||||||
|
nchw.initialize(fs0);
|
||||||
|
|
||||||
|
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
||||||
|
nhwc.initialize(fs1);
|
||||||
|
|
||||||
|
while(nchw.hasNext()){
|
||||||
|
assertTrue(nhwc.hasNext());
|
||||||
|
|
||||||
|
List<Writable> l_nchw = nchw.next();
|
||||||
|
List<Writable> l_nhwc = nhwc.next();
|
||||||
|
|
||||||
|
INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get();
|
||||||
|
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get();
|
||||||
|
|
||||||
|
assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
|
||||||
|
assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());
|
||||||
|
|
||||||
|
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
assertEquals(a_nchw, permuted);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//Test batch:
|
||||||
|
nchw.reset();
|
||||||
|
nhwc.reset();
|
||||||
|
|
||||||
|
int batchCount = 0;
|
||||||
|
while(nchw.hasNext()){
|
||||||
|
assertTrue(nhwc.hasNext());
|
||||||
|
batchCount++;
|
||||||
|
|
||||||
|
List<List<Writable>> l_nchw = nchw.next(3);
|
||||||
|
List<List<Writable>> l_nhwc = nhwc.next(3);
|
||||||
|
assertEquals(3, l_nchw.size());
|
||||||
|
assertEquals(3, l_nhwc.size());
|
||||||
|
|
||||||
|
NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw;
|
||||||
|
NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc;
|
||||||
|
|
||||||
|
INDArray a_nchw = b_nchw.getArrays().get(0);
|
||||||
|
INDArray a_nhwc = b_nhwc.getArrays().get(0);
|
||||||
|
|
||||||
|
assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
|
||||||
|
assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());
|
||||||
|
|
||||||
|
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||||
|
assertEquals(a_nchw, permuted);
|
||||||
|
}
|
||||||
|
assertEquals(2, batchCount);
|
||||||
|
|
||||||
|
|
||||||
|
//Test record(URI, DataInputStream)
|
||||||
|
|
||||||
|
URI u = fs0.locations()[0];
|
||||||
|
|
||||||
|
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||||
|
List<Writable> l = nchw.record(u, dis);
|
||||||
|
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||||
|
assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||||
|
List<Writable> l = nhwc.record(u, dis);
|
||||||
|
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||||
|
assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,12 +8,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.RmsProp;
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.nio.file.Files;
|
||||||
import java.util.concurrent.CountDownLatch;
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
@Ignore
|
@Ignore
|
||||||
|
|
|
@ -18,11 +18,9 @@ package org.deeplearning4j.nn.layers.convolution;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||||
|
@ -35,6 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||||
|
import org.deeplearning4j.util.ConvolutionUtils;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
|
@ -49,6 +48,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
@RunWith(Parameterized.class)
|
@RunWith(Parameterized.class)
|
||||||
public class ConvDataFormatTests extends BaseDL4JTest {
|
public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
|
@ -971,4 +971,58 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWrongFormatIn(){
|
||||||
|
|
||||||
|
for(CNN2DFormat df : CNN2DFormat.values()){
|
||||||
|
|
||||||
|
|
||||||
|
for(int i=0; i<4; i++ ){
|
||||||
|
|
||||||
|
NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
|
||||||
|
.list();
|
||||||
|
switch (i){
|
||||||
|
case 0:
|
||||||
|
b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(b.build());
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
INDArray in;
|
||||||
|
INDArray wrongFormatIn;
|
||||||
|
if(df == CNN2DFormat.NCHW){
|
||||||
|
in = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
|
||||||
|
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
|
||||||
|
} else {
|
||||||
|
in = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
|
||||||
|
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
|
||||||
|
}
|
||||||
|
|
||||||
|
net.output(in);
|
||||||
|
|
||||||
|
try {
|
||||||
|
net.output(wrongFormatIn);
|
||||||
|
} catch (DL4JInvalidInputException e){
|
||||||
|
// e.printStackTrace();
|
||||||
|
String msg = e.getMessage();
|
||||||
|
assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,9 +23,13 @@ import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumula
|
||||||
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
|
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
|
||||||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
|
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.util.PrintAffinity;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.nativeblas.OpaqueDataBuffer;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.Assert.assertTrue;
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -93,12 +97,13 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, null, false);
|
EncodingHandler handler = new EncodingHandler(new FixedThresholdAlgorithm(1e-3), null, Integer.MAX_VALUE, false);
|
||||||
for (int e = 10; e < numParams / 5; e++) {
|
for (int e = 10; e < numParams / 5; e++) {
|
||||||
|
|
||||||
INDArray encoded = handler.encodeUpdates(0, 0, getGradients(numParams, e, 2e-3));
|
val gradients = getGradients(numParams, e, 2e-3);
|
||||||
|
val encoded = handler.encodeUpdates(0, 0, gradients);
|
||||||
|
|
||||||
// log.info("enc len: {}", encoded.data().length());
|
assertNotNull("Failed with e == " + e, encoded);
|
||||||
|
|
||||||
int encFormat = encoded.data().getInt(3);
|
int encFormat = encoded.data().getInt(3);
|
||||||
|
|
||||||
|
|
|
@ -21,9 +21,9 @@ import lombok.val;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
|
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
|
||||||
import org.deeplearning4j.core.util.ThreadUtils;
|
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.util.ThreadUtils;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
|
@ -31,11 +31,6 @@
|
||||||
|
|
||||||
|
|
||||||
<dependencies>
|
<dependencies>
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-util</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.nd4j</groupId>
|
<groupId>org.nd4j</groupId>
|
||||||
<artifactId>nd4j-api</artifactId>
|
<artifactId>nd4j-api</artifactId>
|
||||||
|
|
|
@ -1,75 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.datasets.iterator.impl;
|
|
||||||
|
|
||||||
import org.deeplearning4j.util.MovingWindowMatrix;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
|
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* Moving window data fetcher. Handles rotation of matrices in all directions
|
|
||||||
* to generate more examples.
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class MovingWindowDataSetFetcher extends BaseDataFetcher {
|
|
||||||
|
|
||||||
private DataSet data;
|
|
||||||
private int windowRows = 28, windowColumns = 28;
|
|
||||||
private int cursor = 0;
|
|
||||||
|
|
||||||
public MovingWindowDataSetFetcher(DataSet data, int windowRows, int windowColumns) {
|
|
||||||
this.data = data;
|
|
||||||
this.windowRows = windowRows;
|
|
||||||
this.windowColumns = windowColumns;
|
|
||||||
List<DataSet> list = data.asList();
|
|
||||||
List<DataSet> flipped = new ArrayList<>();
|
|
||||||
for (int i = 0; i < list.size(); i++) {
|
|
||||||
INDArray label = list.get(i).getLabels();
|
|
||||||
List<INDArray> windows =
|
|
||||||
new MovingWindowMatrix(list.get(i).getFeatures(), windowRows, windowColumns, true)
|
|
||||||
.windows(true);
|
|
||||||
for (int j = 0; j < windows.size(); j++) {
|
|
||||||
flipped.add(new DataSet(windows.get(j), label));
|
|
||||||
}
|
|
||||||
flipped.add(list.get(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
this.data = DataSet.merge(flipped);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Fetches the next dataset. You need to call this
|
|
||||||
* to get a new dataset, otherwise {@link #next()}
|
|
||||||
* just returns the last data applyTransformToDestination fetch
|
|
||||||
*
|
|
||||||
* @param numExamples the number of examples to fetch
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void fetch(int numExamples) {
|
|
||||||
initializeCurrFromList(data.get(ArrayUtil.range(cursor, cursor + numExamples)).asList());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo
|
||||||
public class KerasLoss extends KerasLayer {
|
public class KerasLoss extends KerasLayer {
|
||||||
|
|
||||||
private final String KERAS_CLASS_NAME_LOSS = "Loss";
|
private final String KERAS_CLASS_NAME_LOSS = "Loss";
|
||||||
private LossFunctions.LossFunction loss;
|
private ILossFunction loss;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer {
|
||||||
if (enforceTrainingConfig)
|
if (enforceTrainingConfig)
|
||||||
throw e;
|
throw e;
|
||||||
log.warn("Unsupported Keras loss function. Replacing with MSE.");
|
log.warn("Unsupported Keras loss function. Replacing with MSE.");
|
||||||
loss = LossFunctions.LossFunction.SQUARED_LOSS;
|
loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
 import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
+import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
+import java.util.HashMap;
+import java.util.Map;
+
 /**
  * Utility functionality for keras loss functions
  *

@@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
  */
 @Slf4j
 public class KerasLossUtils {
+    static final Map<String, ILossFunction> customLoss = new HashMap<>();
+
+    /**
+     * Register a custom loss function
+     *
+     * @param lossName     name of the lambda layer in the serialized Keras model
+     * @param lossFunction SameDiffLambdaLayer instance to map to Keras Lambda layer
+     */
+    public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
+        customLoss.put(lossName, lossFunction);
+    }
+
+    /**
+     * Clear all lambda layers
+     *
+     */
+    public static void clearCustomLoss() {
+        customLoss.clear();
+    }
+
     /**
      * Map Keras to DL4J loss functions.
      *
      * @param kerasLoss String containing Keras loss function name
      * @return String containing DL4J loss function
      */
-    public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
+    public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
             throws UnsupportedKerasConfigurationException {
         LossFunctions.LossFunction dl4jLoss;
         if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||

@@ -67,8 +93,13 @@ public class KerasLossUtils {
         } else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
             dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
         } else {
-            throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
+            ILossFunction lossClass = customLoss.get(kerasLoss);
+            if (lossClass != null) {
+                return lossClass;
+            } else {
+                throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
+            }
         }
-        return dl4jLoss;
+        return dl4jLoss.getILossFunction();
     }
 }
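With the registry above, a custom Keras loss only needs to be registered under the name it was saved with before the model is imported, and cleared again afterwards. A minimal sketch assuming the same "logcosh" name and LogCosh implementation used by the new test added below (which exercises exactly this flow end to end):

    // Hypothetical call site; the loss name must match the name used when the Keras model was compiled.
    KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
    try {
        // ... import the model via KerasSequentialModel / KerasModelImport as usual ...
    } finally {
        // Always clear the registry so other imports are not affected.
        KerasLossUtils.clearCustomLoss();
    }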
@@ -0,0 +1,78 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.deeplearning4j.nn.modelimport.keras.e2e;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
+import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.common.resources.Resources;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.lossfunctions.SameDiffLoss;
+
+import java.io.File;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.StandardCopyOption;
+
+/**
+ * Test importing Keras models with custom loss.
+ *
+ * @author Paul Dubs
+ */
+public class KerasCustomLossTest extends BaseDL4JTest {
+
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
+    public class LogCosh extends SameDiffLoss {
+        @Override
+        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+            return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
+        }
+    }
+
+    @Test
+    public void testSequentialLambdaLayerImport() throws Exception {
+        KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
+
+        String modelPath = "modelimport/keras/examples/custom_loss.h5";
+
+        try(InputStream is = Resources.asStream(modelPath)) {
+            File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
+            Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
+            MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
+                    .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
+
+            System.out.println(model.summary());
+            INDArray input = Nd4j.create(new int[]{10, 3});
+
+            model.output(input);
+        } finally {
+            KerasLossUtils.clearCustomLoss();
+        }
+    }
+
+}
@@ -856,15 +856,26 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
 
     @Test
     public void testFastText() {
-        File[] files = {fastTextRaw, fastTextZip, fastTextGzip};
+        File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
         for (File file : files) {
             try {
                 Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);
                 assertEquals(99, word2Vec.getVocab().numWords());
-            } catch (Exception e) {
-                fail("Failure for input file " + file.getAbsolutePath() + " " + e.getMessage());
+            } catch (Exception readCsvException) {
+                fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
+            }
+        }
+    }
+
+    @Test
+    public void testFastText_readWord2VecModel() {
+        File[] files = { fastTextRaw, fastTextZip, fastTextGzip };
+        for (File file : files) {
+            try {
+                Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel(file);
+                assertEquals(99, word2Vec.getVocab().numWords());
+            } catch (Exception readCsvException) {
+                fail("Failure for input file " + file.getAbsolutePath() + " " + readCsvException.getMessage());
             }
         }
     }
@@ -84,6 +84,12 @@
             <version>${project.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.awaitility</groupId>
+            <artifactId>awaitility</artifactId>
+            <version>4.0.2</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 
     <profiles>
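For reference, the Awaitility test dependency added above is normally used to poll for an asynchronous condition instead of sleeping in a test. A minimal sketch; the backgroundJob object and the 30-second timeout are illustrative only and not part of this change:

    import static org.awaitility.Awaitility.await;
    import java.util.concurrent.TimeUnit;

    // Block the test for up to 30 seconds until the asynchronous work reports completion.
    await().atMost(30, TimeUnit.SECONDS).until(() -> backgroundJob.isDone());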
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -16,14 +17,45 @@
 
 package org.deeplearning4j.models.embeddings.loader;
 
-import lombok.*;
-import lombok.extern.slf4j.Slf4j;
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.ByteArrayInputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.OutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.io.UnsupportedEncodingException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipEntry;
+import java.util.zip.ZipFile;
+import java.util.zip.ZipInputStream;
+import java.util.zip.ZipOutputStream;
+
 import org.apache.commons.codec.binary.Base64;
 import org.apache.commons.compress.compressors.gzip.GzipUtils;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.io.LineIterator;
 import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.deeplearning4j.common.util.DL4JFileUtils;
 import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.models.embeddings.WeightLookupTable;
 import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;

@@ -50,26 +82,25 @@ import org.deeplearning4j.text.documentiterator.LabelsSource;
 import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
 import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.deeplearning4j.common.util.DL4JFileUtils;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.common.util.OneTimeLogger;
 import org.nd4j.compression.impl.NoOp;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
-import org.nd4j.common.primitives.Pair;
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.MapperFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 import org.nd4j.shade.jackson.databind.SerializationFeature;
 import org.nd4j.storage.CompressedRamStorage;
-import org.nd4j.common.util.OneTimeLogger;
 
-import java.io.*;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.zip.*;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import lombok.val;
 
 /**
  * This is utility class, providing various methods for WordVectors serialization

@@ -85,14 +116,17 @@ import java.util.zip.*;
  * {@link #writeWord2VecModel(Word2Vec, OutputStream)}
  *
  * <li>Deserializers for Word2Vec:</li>
- * {@link #readWord2VecModel(File)}
  * {@link #readWord2VecModel(String)}
- * {@link #readWord2VecModel(File, boolean)}
  * {@link #readWord2VecModel(String, boolean)}
+ * {@link #readWord2VecModel(File)}
+ * {@link #readWord2VecModel(File, boolean)}
  * {@link #readAsBinaryNoLineBreaks(File)}
+ * {@link #readAsBinaryNoLineBreaks(InputStream)}
  * {@link #readAsBinary(File)}
+ * {@link #readAsBinary(InputStream)}
  * {@link #readAsCsv(File)}
- * {@link #readBinaryModel(File, boolean, boolean)}
+ * {@link #readAsCsv(InputStream)}
+ * {@link #readBinaryModel(InputStream, boolean, boolean)}
  * {@link #readWord2VecFromText(File, File, File, File, VectorsConfiguration)}
  * {@link #readWord2Vec(String, boolean)}
  * {@link #readWord2Vec(File, boolean)}

@@ -117,6 +151,7 @@ import java.util.zip.*;
  * {@link #fromTableAndVocab(WeightLookupTable, VocabCache)}
  * {@link #fromPair(Pair)}
  * {@link #loadTxt(File)}
+ * {@link #loadTxt(InputStream)}
  *
  * <li>Serializers to tSNE format</li>
  * {@link #writeTsneFormat(Glove, INDArray, File)}

@@ -151,6 +186,7 @@ import java.util.zip.*;
  * @author Adam Gibson
  * @author raver119
  * @author alexander@skymind.io
+ * @author Alexei KLENIN
  */
 @Slf4j
 public class WordVectorSerializer {

@@ -215,18 +251,22 @@ public class WordVectorSerializer {
     }*/
 
     /**
-     * Read a binary word2vec file.
+     * Read a binary word2vec from input stream.
+     *
+     * @param inputStream input stream to read
+     * @param linebreaks  if true, the reader expects each word/vector to be in a separate line, terminated
+     *                    by a line break
+     * @param normalize
      *
-     * @param modelFile the File to read
-     * @param linebreaks if true, the reader expects each word/vector to be in a separate line, terminated
-     *                   by a line break
     * @return a {@link Word2Vec model}
     * @throws NumberFormatException
     * @throws IOException
     * @throws FileNotFoundException
     */
-    public static Word2Vec readBinaryModel(File modelFile, boolean linebreaks, boolean normalize)
-                    throws NumberFormatException, IOException {
+    public static Word2Vec readBinaryModel(
+            InputStream inputStream,
+            boolean linebreaks,
+            boolean normalize) throws NumberFormatException, IOException {
        InMemoryLookupTable<VocabWord> lookupTable;
        VocabCache<VocabWord> cache;
        INDArray syn0;

@@ -240,9 +280,7 @@
 
         Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
 
-        try (BufferedInputStream bis = new BufferedInputStream(GzipUtils.isCompressedFilename(modelFile.getName())
-                        ? new GZIPInputStream(new FileInputStream(modelFile)) : new FileInputStream(modelFile));
-                        DataInputStream dis = new DataInputStream(bis)) {
+        try (DataInputStream dis = new DataInputStream(inputStream)) {
             words = Integer.parseInt(ReadHelper.readString(dis));
             size = Integer.parseInt(ReadHelper.readString(dis));
             syn0 = Nd4j.create(words, size);

@@ -250,23 +288,26 @@
 
             printOutProjectedMemoryUse(words, size, 1);
 
-            lookupTable = (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>().cache(cache)
-                            .useHierarchicSoftmax(false).vectorLength(size).build();
+            lookupTable = new InMemoryLookupTable.Builder<VocabWord>()
+                    .cache(cache)
+                    .useHierarchicSoftmax(false)
+                    .vectorLength(size)
+                    .build();
 
-            int cnt = 0;
             String word;
             float[] vector = new float[size];
             for (int i = 0; i < words; i++) {
+
                 word = ReadHelper.readString(dis);
-                log.trace("Loading " + word + " with word " + i);
+                log.trace("Loading {} with word {}", word, i);
 
                 for (int j = 0; j < size; j++) {
                     vector[j] = ReadHelper.readFloat(dis);
                 }
 
-                if (cache.containsWord(word))
-                    throw new ND4JIllegalStateException("Tried to add existing word. Probably time to switch linebreaks mode?");
+                if (cache.containsWord(word)) {
+                    throw new ND4JIllegalStateException(
+                            "Tried to add existing word. Probably time to switch linebreaks mode?");
+                }
 
                 syn0.putRow(i, normalize ? Transforms.unitVec(Nd4j.create(vector)) : Nd4j.create(vector));

@@ -285,25 +326,31 @@
                 Nd4j.getMemoryManager().invokeGcOccasionally();
             }
         } finally {
-            if (originalPeriodic)
+            if (originalPeriodic) {
                 Nd4j.getMemoryManager().togglePeriodicGc(true);
+            }
 
             Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
         }
 
         lookupTable.setSyn0(syn0);
 
-        Word2Vec ret = new Word2Vec.Builder().useHierarchicSoftmax(false).resetModel(false).layerSize(syn0.columns())
-                        .allowParallelTokenization(true).elementsLearningAlgorithm(new SkipGram<VocabWord>())
-                        .learningRate(0.025).windowSize(5).workers(1).build();
+        Word2Vec ret = new Word2Vec
+                .Builder()
+                .useHierarchicSoftmax(false)
+                .resetModel(false)
+                .layerSize(syn0.columns())
+                .allowParallelTokenization(true)
+                .elementsLearningAlgorithm(new SkipGram<VocabWord>())
+                .learningRate(0.025)
+                .windowSize(5)
+                .workers(1)
+                .build();
 
         ret.setVocab(cache);
         ret.setLookupTable(lookupTable);
 
         return ret;
     }
 
     /**

@@ -927,7 +974,7 @@
     public static Word2Vec readWord2VecFromText(@NonNull File vectors, @NonNull File hs, @NonNull File h_codes,
                     @NonNull File h_points, @NonNull VectorsConfiguration configuration) throws IOException {
         // first we load syn0
-        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectors);
+        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(new FileInputStream(vectors));
         InMemoryLookupTable lookupTable = pair.getFirst();
         lookupTable.setNegative(configuration.getNegative());
         if (configuration.getNegative() > 0)

@@ -1604,160 +1651,172 @@
      * @param vectorsFile the path of the file to load\
      * @return
      * @throws FileNotFoundException if the file does not exist
-     * @deprecated Use {@link #loadTxt(File)}
+     * @deprecated Use {@link #loadTxt(InputStream)}
      */
     @Deprecated
-    public static WordVectors loadTxtVectors(File vectorsFile)
-                    throws IOException {
-        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(vectorsFile);
+    public static WordVectors loadTxtVectors(File vectorsFile) throws IOException {
+        FileInputStream fileInputStream = new FileInputStream(vectorsFile);
+        Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(fileInputStream);
         return fromPair(pair);
     }
 
+    static InputStream fileStream(@NonNull File file) throws IOException {
+        boolean isZip = file.getName().endsWith(".zip");
+        boolean isGzip = GzipUtils.isCompressedFilename(file.getName());
+
+        InputStream inputStream;
+
+        if (isZip) {
+            inputStream = decompressZip(file);
+        } else if (isGzip) {
+            FileInputStream fis = new FileInputStream(file);
+            inputStream = new GZIPInputStream(fis);
+        } else {
+            inputStream = new FileInputStream(file);
+        }
+
+        return new BufferedInputStream(inputStream);
+    }
+
     private static InputStream decompressZip(File modelFile) throws IOException {
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
         ZipFile zipFile = new ZipFile(modelFile);
         InputStream inputStream = null;
 
-        try (ZipInputStream zipStream = new ZipInputStream(new BufferedInputStream(new FileInputStream(modelFile)))) {
-            ZipEntry entry = null;
+        try (FileInputStream fis = new FileInputStream(modelFile);
+             BufferedInputStream bis = new BufferedInputStream(fis);
+             ZipInputStream zipStream = new ZipInputStream(bis)) {
+            ZipEntry entry;
             if ((entry = zipStream.getNextEntry()) != null) {
                 inputStream = zipFile.getInputStream(entry);
             }
 
             if (zipStream.getNextEntry() != null) {
                 throw new RuntimeException("Zip archive " + modelFile + " contains more than 1 file");
             }
         }
 
         return inputStream;
     }
 
-    private static BufferedReader createReader(File vectorsFile) throws IOException {
-        InputStreamReader inputStreamReader;
-        try {
-            inputStreamReader = new InputStreamReader(decompressZip(vectorsFile));
-        } catch (IOException e) {
-            inputStreamReader = new InputStreamReader(GzipUtils.isCompressedFilename(vectorsFile.getName())
-                            ? new GZIPInputStream(new FileInputStream(vectorsFile))
-                            : new FileInputStream(vectorsFile), "UTF-8");
-        }
-        BufferedReader reader = new BufferedReader(inputStreamReader);
-        return reader;
-    }
+    public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull File file) {
+        try (InputStream inputStream = fileStream(file)) {
+            return loadTxt(inputStream);
+        } catch (IOException readTestException) {
+            throw new RuntimeException(readTestException);
+        }
+    }
 
     /**
-     * Loads an in memory cache from the given path (sets syn0 and the vocab)
+     * Loads an in memory cache from the given input stream (sets syn0 and the vocab).
      *
-     * @param vectorsFile the path of the file to load
-     * @return a Pair holding the lookup table and the vocab cache.
-     * @throws FileNotFoundException if the input file does not exist
+     * @param inputStream input stream
+     * @return a {@link Pair} holding the lookup table and the vocab cache.
      */
-    public static Pair<InMemoryLookupTable, VocabCache> loadTxt(File vectorsFile)
-                    throws IOException, UnsupportedEncodingException {
-        AbstractCache cache = new AbstractCache<>();
-        BufferedReader reader = createReader(vectorsFile);
-        LineIterator iter = IOUtils.lineIterator(reader);
-        String line = null;
-        boolean hasHeader = false;
-        if (iter.hasNext()) {
-            line = iter.nextLine(); // skip header line
-            //look for spaces
-            if (!line.contains(" ")) {
-                log.debug("Skipping first line");
-                hasHeader = true;
-            } else {
-                // we should check for something that looks like proper word vectors here. i.e: 1 word at the 0 position, and bunch of floats further
-                String[] split = line.split(" ");
-                try {
-                    long[] header = new long[split.length];
-                    for (int x = 0; x < split.length; x++) {
-                        header[x] = Long.parseLong(split[x]);
-                    }
-                    if (split.length < 4)
-                        hasHeader = true;
-                    // now we know, if that's all ints - it's just a header
-                    // [0] - number of words
-                    // [1] - vectorSize
-                    // [2] - number of documents <-- DL4j-only value
-                    if (split.length == 3)
-                        cache.incrementTotalDocCount(header[2]);
-
-                    printOutProjectedMemoryUse(header[0], (int) header[1], 1);
-
-                    hasHeader = true;
-
-                    try {
-                        reader.close();
-                    } catch (Exception ex) {
-                    }
-                } catch (Exception e) {
-                    // if any conversion exception hits - that'll be considered header
-                    hasHeader = false;
-                }
-            }
-        }
-
-        //reposition buffer to be one line ahead
-        if (hasHeader) {
-            line = "";
-            iter.close();
-            //reader = new BufferedReader(new FileReader(vectorsFile));
-            reader = createReader(vectorsFile);
-            iter = IOUtils.lineIterator(reader);
-            iter.nextLine();
-        }
-
-        List<INDArray> arrays = new ArrayList<>();
-        long[] vShape = new long[]{1, -1};
-        while (iter.hasNext()) {
-            if (line.isEmpty())
-                line = iter.nextLine();
-            String[] split = line.split(" ");
-            String word = ReadHelper.decodeB64(split[0]); //split[0].replaceAll(whitespaceReplacement, " ");
-            VocabWord word1 = new VocabWord(1.0, word);
-
-            word1.setIndex(cache.numWords());
-
-            cache.addToken(word1);
-
-            cache.addWordToIndex(word1.getIndex(), word);
-
-            cache.putVocabWord(word);
-
-            float[] vector = new float[split.length - 1];
-
-            for (int i = 1; i < split.length; i++) {
-                vector[i - 1] = Float.parseFloat(split[i]);
-            }
-
-            vShape[1] = vector.length;
-            INDArray row = Nd4j.create(vector, vShape);
-
-            arrays.add(row);
-
-            // workaround for skipped first row
-            line = "";
-        }
-
-        INDArray syn = Nd4j.vstack(arrays);
-
-        InMemoryLookupTable lookupTable =
-                        (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(arrays.get(0).columns())
-                                        .useAdaGrad(false).cache(cache).useHierarchicSoftmax(false).build();
-
-        lookupTable.setSyn0(syn);
-
-        iter.close();
-
-        try {
-            reader.close();
-        } catch (Exception e) {
-        }
-
-        return new Pair<>(lookupTable, (VocabCache) cache);
+    public static Pair<InMemoryLookupTable, VocabCache> loadTxt(@NonNull InputStream inputStream) {
+        AbstractCache<VocabWord> cache = new AbstractCache<>();
+        LineIterator lines = null;
+
+        try (InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
+             BufferedReader reader = new BufferedReader(inputStreamReader)) {
+            lines = IOUtils.lineIterator(reader);
+
+            String line = null;
+            boolean hasHeader = false;
+
+            /* Check if first line is a header */
+            if (lines.hasNext()) {
+                line = lines.nextLine();
+                hasHeader = isHeader(line, cache);
+            }
+
+            if (hasHeader) {
+                log.debug("First line is a header");
+                line = lines.nextLine();
+            }
+
+            List<INDArray> arrays = new ArrayList<>();
+            long[] vShape = new long[]{ 1, -1 };
+
+            do {
+                String[] tokens = line.split(" ");
+                String word = ReadHelper.decodeB64(tokens[0]);
+                VocabWord vocabWord = new VocabWord(1.0, word);
+                vocabWord.setIndex(cache.numWords());
+
+                cache.addToken(vocabWord);
+                cache.addWordToIndex(vocabWord.getIndex(), word);
+                cache.putVocabWord(word);
+
+                float[] vector = new float[tokens.length - 1];
+                for (int i = 1; i < tokens.length; i++) {
+                    vector[i - 1] = Float.parseFloat(tokens[i]);
+                }
+
+                vShape[1] = vector.length;
+                INDArray row = Nd4j.create(vector, vShape);
+
+                arrays.add(row);
+
+                line = lines.hasNext() ? lines.next() : null;
+            } while (line != null);
+
+            INDArray syn = Nd4j.vstack(arrays);
+
+            InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
+                    .Builder<VocabWord>()
+                    .vectorLength(arrays.get(0).columns())
+                    .useAdaGrad(false)
+                    .cache(cache)
+                    .useHierarchicSoftmax(false)
+                    .build();
+
+            lookupTable.setSyn0(syn);
+
+            return new Pair<>((InMemoryLookupTable) lookupTable, (VocabCache) cache);
+        } catch (IOException readeTextStreamException) {
+            throw new RuntimeException(readeTextStreamException);
+        } finally {
+            if (lines != null) {
+                lines.close();
+            }
+        }
+    }
+
+    static boolean isHeader(String line, AbstractCache cache) {
+        if (!line.contains(" ")) {
+            return true;
+        } else {
+            /* We should check for something that looks like proper word vectors here. i.e: 1 word at the 0
+             * position, and bunch of floats further */
+            String[] headers = line.split(" ");
+
+            try {
+                long[] header = new long[headers.length];
+                for (int x = 0; x < headers.length; x++) {
+                    header[x] = Long.parseLong(headers[x]);
+                }
+
+                /* Now we know, if that's all ints - it's just a header
+                 * [0] - number of words
+                 * [1] - vectorLength
+                 * [2] - number of documents <-- DL4j-only value
+                 */
+                if (headers.length == 3) {
+                    long numberOfDocuments = header[2];
+                    cache.incrementTotalDocCount(numberOfDocuments);
+                }
+
+                long numWords = header[0];
+                int vectorLength = (int) header[1];
+                printOutProjectedMemoryUse(numWords, vectorLength, 1);
+
+                return true;
+            } catch (Exception notHeaderException) {
+                // if any conversion exception hits - that'll be considered header
+                return false;
+            }
+        }
     }
 
     /**

@@ -2352,22 +2411,6 @@
         }
     }
 
-    /**
-     * This method
-     * 1) Binary model, either compressed or not. Like well-known Google Model
-     * 2) Popular CSV word2vec text format
-     * 3) DL4j compressed format
-     * <p>
-     * Please note: Only weights will be loaded by this method.
-     *
-     * @param file
-     * @return
-     */
-    public static Word2Vec readWord2VecModel(@NonNull File file) {
-        return readWord2VecModel(file, false);
-    }
-
-
     /**
      * This method
     * 1) Binary model, either compressed or not. Like well-known Google Model

@@ -2389,106 +2432,196 @@
      * 2) Popular CSV word2vec text format
      * 3) DL4j compressed format
      * <p>
-     * Please note: if extended data isn't available, only weights will be loaded instead.
+     * Please note: Only weights will be loaded by this method.
      *
-     * @param path
+     * @param path path to model file
      * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
      * @return
      */
     public static Word2Vec readWord2VecModel(String path, boolean extendedModel) {
         return readWord2VecModel(new File(path), extendedModel);
     }
 
-    public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+    /**
+     * This method
+     * 1) Binary model, either compressed or not. Like well-known Google Model
+     * 2) Popular CSV word2vec text format
+     * 3) DL4j compressed format
+     * <p>
+     * Please note: Only weights will be loaded by this method.
+     *
+     * @param file
+     * @return
+     */
+    public static Word2Vec readWord2VecModel(File file) {
+        return readWord2VecModel(file, false);
+    }
+
+    /**
+     * This method
+     * 1) Binary model, either compressed or not. Like well-known Google Model
+     * 2) Popular CSV word2vec text format
+     * 3) DL4j compressed format
+     * <p>
+     * Please note: if extended data isn't available, only weights will be loaded instead.
+     *
+     * @param file model file
+     * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
+     * @return word2vec model
+     */
+    public static Word2Vec readWord2VecModel(File file, boolean extendedModel) {
+        if (!file.exists() || !file.isFile()) {
+            throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
+        }
+
+        boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
+        if (originalPeriodic) {
+            Nd4j.getMemoryManager().togglePeriodicGc(false);
+        }
+        Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
+
+        try {
+            return readWord2Vec(file, extendedModel);
+        } catch (Exception readSequenceVectors) {
+            try {
+                return extendedModel
+                        ? readAsExtendedModel(file)
+                        : readAsSimplifiedModel(file);
+            } catch (Exception loadFromFileException) {
+                try {
+                    return readAsCsv(file);
+                } catch (Exception readCsvException) {
+                    try {
+                        return readAsBinary(file);
+                    } catch (Exception readBinaryException) {
+                        try {
+                            return readAsBinaryNoLineBreaks(file);
+                        } catch (Exception readModelException) {
+                            log.error("Unable to guess input file format", readModelException);
+                            throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    public static Word2Vec readAsBinaryNoLineBreaks(@NonNull File file) {
+        try (InputStream inputStream = fileStream(file)) {
+            return readAsBinaryNoLineBreaks(inputStream);
+        } catch (IOException readCsvException) {
+            throw new RuntimeException(readCsvException);
+        }
+    }
+
+    public static Word2Vec readAsBinaryNoLineBreaks(@NonNull InputStream inputStream) {
         boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
         int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
-        Word2Vec vec;
 
         // try to load without linebreaks
         try {
-            if (originalPeriodic)
+            if (originalPeriodic) {
                 Nd4j.getMemoryManager().togglePeriodicGc(true);
+            }
 
             Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
 
-            vec = readBinaryModel(file, false, false);
-            return vec;
-        } catch (Exception ez) {
-            throw new RuntimeException(
-                            "Unable to guess input file format. Please use corresponding loader directly");
+            return readBinaryModel(inputStream, false, false);
+        } catch (Exception readModelException) {
+            log.error("Cannot read binary model", readModelException);
+            throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
+        }
+    }
+
+    public static Word2Vec readAsBinary(@NonNull File file) {
+        try (InputStream inputStream = fileStream(file)) {
+            return readAsBinary(inputStream);
+        } catch (IOException readCsvException) {
+            throw new RuntimeException(readCsvException);
         }
     }
 
     /**
-     * This method loads Word2Vec model from binary file
+     * This method loads Word2Vec model from binary input stream.
      *
-     * @param file File
+     * @param inputStream binary input stream
      * @return Word2Vec
     */
-    public static Word2Vec readAsBinary(@NonNull File file) {
+    public static Word2Vec readAsBinary(@NonNull InputStream inputStream) {
         boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
         int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
 
-        Word2Vec vec;
-
         // we fallback to trying binary model instead
         try {
             log.debug("Trying binary model restoration...");
 
-            if (originalPeriodic)
+            if (originalPeriodic) {
                 Nd4j.getMemoryManager().togglePeriodicGc(true);
+            }
 
             Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
 
-            vec = readBinaryModel(file, true, false);
-            return vec;
-        } catch (Exception ey) {
-            throw new RuntimeException(ey);
+            return readBinaryModel(inputStream, true, false);
+        } catch (Exception readModelException) {
+            throw new RuntimeException(readModelException);
+        }
+    }
+
+    public static Word2Vec readAsCsv(@NonNull File file) {
+        try (InputStream inputStream = fileStream(file)) {
+            return readAsCsv(inputStream);
+        } catch (IOException readCsvException) {
+            throw new RuntimeException(readCsvException);
         }
     }
 
     /**
      * This method loads Word2Vec model from csv file
      *
-     * @param file File
-     * @return Word2Vec
+     * @param inputStream input stream
+     * @return Word2Vec model
     */
-    public static Word2Vec readAsCsv(@NonNull File file) {
-
-        Word2Vec vec;
+    public static Word2Vec readAsCsv(@NonNull InputStream inputStream) {
         VectorsConfiguration configuration = new VectorsConfiguration();
 
         // let's try to load this file as csv file
         try {
             log.debug("Trying CSV model restoration...");
 
-            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(file);
-            Word2Vec.Builder builder = new Word2Vec.Builder().lookupTable(pair.getFirst()).useAdaGrad(false)
-                            .vocabCache(pair.getSecond()).layerSize(pair.getFirst().layerSize())
-                            // we don't use hs here, because model is incomplete
-                            .useHierarchicSoftmax(false).resetModel(false);
+            Pair<InMemoryLookupTable, VocabCache> pair = loadTxt(inputStream);
+            Word2Vec.Builder builder = new Word2Vec
+                    .Builder()
+                    .lookupTable(pair.getFirst())
+                    .useAdaGrad(false)
+                    .vocabCache(pair.getSecond())
+                    .layerSize(pair.getFirst().layerSize())
+                    // we don't use hs here, because model is incomplete
+                    .useHierarchicSoftmax(false)
+                    .resetModel(false);
 
             TokenizerFactory factory = getTokenizerFactory(configuration);
-            if (factory != null)
+            if (factory != null) {
                 builder.tokenizerFactory(factory);
+            }
 
-            vec = builder.build();
-            return vec;
+            return builder.build();
         } catch (Exception ex) {
             throw new RuntimeException("Unable to load model in CSV format");
         }
     }
 
+    /**
+     * This method just loads full compressed model.
+     */
     private static Word2Vec readAsExtendedModel(@NonNull File file) throws IOException {
         int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
         boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
 
         log.debug("Trying full model restoration...");
-        // this method just loads full compressed model
 
-        if (originalPeriodic)
+        if (originalPeriodic) {
             Nd4j.getMemoryManager().togglePeriodicGc(true);
+        }
 
         Nd4j.getMemoryManager().setOccasionalGcFrequency(originalFreq);
 

@@ -2627,67 +2760,6 @@
         return vec;
     }
 
-    /**
-     * This method
-     * 1) Binary model, either compressed or not. Like well-known Google Model
-     * 2) Popular CSV word2vec text format
-     * 3) DL4j compressed format
-     * <p>
-     * Please note: if extended data isn't available, only weights will be loaded instead.
-     *
-     * @param file
-     * @param extendedModel if TRUE, we'll try to load HS states & Huffman tree info, if FALSE, only weights will be loaded
-     * @return
-     */
-    public static Word2Vec readWord2VecModel(@NonNull File file, boolean extendedModel) {
-
-        if (!file.exists() || !file.isFile())
-            throw new ND4JIllegalStateException("File [" + file.getAbsolutePath() + "] doesn't exist");
-
-        Word2Vec vec = null;
-
-        int originalFreq = Nd4j.getMemoryManager().getOccasionalGcFrequency();
-        boolean originalPeriodic = Nd4j.getMemoryManager().isPeriodicGcActive();
-        if (originalPeriodic)
-            Nd4j.getMemoryManager().togglePeriodicGc(false);
-        Nd4j.getMemoryManager().setOccasionalGcFrequency(50000);
-
-        // try to load zip format
-        try {
-            vec = readWord2Vec(file, extendedModel);
-            return vec;
-        } catch (Exception e) {
-            // let's try to load this file as csv file
-            try {
-                if (extendedModel) {
-                    vec = readAsExtendedModel(file);
-                    return vec;
-                } else {
-                    vec = readAsSimplifiedModel(file);
-                    return vec;
-                }
-            } catch (Exception ex) {
-                try {
-                    vec = readAsCsv(file);
-                    return vec;
-                } catch (Exception exc) {
-                    try {
-                        vec = readAsBinary(file);
-                        return vec;
-                    } catch (Exception exce) {
-                        try {
-                            vec = readAsBinaryNoLineBreaks(file);
-                            return vec;
-
-                        } catch (Exception excep) {
-                            throw new RuntimeException("Unable to guess input file format. Please use corresponding loader directly");
-                        }
-                    }
-                }
-            }
-        }
-    }
-
     protected static TokenizerFactory getTokenizerFactory(VectorsConfiguration configuration) {
         if (configuration == null)
             return null;

@@ -3019,16 +3091,13 @@
     /**
      * This method restores Word2Vec model from file
      *
-     * @param path String
-     * @param readExtendedTables booleab
+     * @param path
+     * @param readExtendedTables
      * @return Word2Vec
     */
-    public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables)
-                    throws IOException {
-
+    public static Word2Vec readWord2Vec(@NonNull String path, boolean readExtendedTables) {
         File file = new File(path);
-        Word2Vec word2Vec = readWord2Vec(file, readExtendedTables);
-        return word2Vec;
+        return readWord2Vec(file, readExtendedTables);
     }
 
     /**

@@ -3139,11 +3208,12 @@
      * @param readExtendedTables boolean
      * @return Word2Vec
     */
-    public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables)
-                    throws IOException {
-
-        Word2Vec word2Vec = readWord2Vec(new FileInputStream(file), readExtendedTables);
-        return word2Vec;
+    public static Word2Vec readWord2Vec(@NonNull File file, boolean readExtendedTables) {
+        try (InputStream inputStream = fileStream(file)) {
+            return readWord2Vec(inputStream, readExtendedTables);
+        } catch (Exception readSequenceVectors) {
+            throw new RuntimeException(readSequenceVectors);
+        }
     }
 
     /**

@@ -3153,13 +3223,19 @@
      * @param readExtendedTable boolean
      * @return Word2Vec
     */
-    public static Word2Vec readWord2Vec(@NonNull InputStream stream,
-                    boolean readExtendedTable) throws IOException {
+    public static Word2Vec readWord2Vec(
+            @NonNull InputStream stream,
+            boolean readExtendedTable) throws IOException {
         SequenceVectors<VocabWord> vectors = readSequenceVectors(stream, readExtendedTable);
-        Word2Vec word2Vec = new Word2Vec.Builder(vectors.getConfiguration()).layerSize(vectors.getLayerSize()).build();
+
+        Word2Vec word2Vec = new Word2Vec
+                .Builder(vectors.getConfiguration())
+                .layerSize(vectors.getLayerSize())
+                .build();
         word2Vec.setVocab(vectors.getVocab());
         word2Vec.setLookupTable(vectors.lookupTable());
         word2Vec.setModelUtils(vectors.getModelUtils());
 
         return word2Vec;
     }
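Taken together, the stream-based refactoring above means callers normally only touch the high-level entry points. A short usage sketch with illustrative file names (not part of this change):

    // readWord2VecModel tries the DL4J zip format, extended/simplified models, CSV,
    // binary and binary-without-linebreaks formats in that order before giving up.
    Word2Vec model = WordVectorSerializer.readWord2VecModel(new File("wordvectors.bin"));

    // Text/CSV vectors can also be loaded directly from any InputStream.
    try (InputStream in = new FileInputStream("glove.txt")) {
        Pair<InMemoryLookupTable, VocabCache> pair = WordVectorSerializer.loadTxt(in);
        WordVectors vectors = WordVectorSerializer.fromPair(pair);
    }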
||||||
|
|
|
@@ -29,7 +29,7 @@ import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
 import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
 import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
 import org.deeplearning4j.common.util.DL4JFileUtils;
-import org.deeplearning4j.core.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.common.primitives.Pair;
 import org.slf4j.Logger;

@@ -47,7 +47,7 @@ import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
 import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
 import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
-import org.deeplearning4j.core.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;

@@ -47,7 +47,7 @@ import org.deeplearning4j.models.word2vec.VocabWord;
 import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
 import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
 import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
-import org.deeplearning4j.core.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.memory.enums.LearningPolicy;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -27,7 +27,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
 import org.deeplearning4j.models.word2vec.Huffman;
 import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
 import org.deeplearning4j.text.invertedindex.InvertedIndex;
-import org.deeplearning4j.core.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.threadly.concurrent.PriorityScheduler;

@@ -18,7 +18,7 @@ package org.deeplearning4j.text.sentenceiterator;

 import lombok.NonNull;

-import org.deeplearning4j.core.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -37,8 +37,6 @@ import java.io.File;
 import java.util.ArrayList;
 import java.util.List;

-import static org.junit.Assert.assertEquals;

 @Slf4j
 public class TsneTest extends BaseDL4JTest {

@@ -14,17 +14,14 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.deeplearning4j.models.sequencevectors.serialization;
+package org.deeplearning4j.models.embeddings.loader;

 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
-import org.apache.commons.lang.StringUtils;
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.models.embeddings.WeightLookupTable;
 import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
 import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
-import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
-import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
 import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
 import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
 import org.deeplearning4j.models.fasttext.FastText;
@@ -47,7 +44,11 @@ import java.io.File;
 import java.io.IOException;
 import java.util.Collections;

-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;

 @Slf4j
 public class WordVectorSerializerTest extends BaseDL4JTest {
@@ -78,10 +79,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
                 syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
                 syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

-        InMemoryLookupTable<VocabWord> lookupTable =
-                (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
-                        .useAdaGrad(false).cache(cache)
-                        .build();
+        InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
+                .Builder<VocabWord>()
+                .useAdaGrad(false)
+                .cache(cache)
+                .build();

         lookupTable.setSyn0(syn0);
         lookupTable.setSyn1(syn1);
@@ -92,7 +94,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
                 lookupTable(lookupTable).
                 build();
         SequenceVectors<VocabWord> deser = null;
-        String json = StringUtils.EMPTY;
         try {
             ByteArrayOutputStream baos = new ByteArrayOutputStream();
             WordVectorSerializer.writeSequenceVectors(vectors, baos);
@@ -126,10 +127,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
                 syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
                 syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

-        InMemoryLookupTable<VocabWord> lookupTable =
-                (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
-                        .useAdaGrad(false).cache(cache)
-                        .build();
+        InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
+                .Builder<VocabWord>()
+                .useAdaGrad(false)
+                .cache(cache)
+                .build();

         lookupTable.setSyn0(syn0);
         lookupTable.setSyn1(syn1);
@@ -204,10 +206,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
                 syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
                 syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

-        InMemoryLookupTable<VocabWord> lookupTable =
-                (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
-                        .useAdaGrad(false).cache(cache)
-                        .build();
+        InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
+                .Builder<VocabWord>()
+                .useAdaGrad(false)
+                .cache(cache)
+                .build();

         lookupTable.setSyn0(syn0);
         lookupTable.setSyn1(syn1);
@@ -252,10 +255,11 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
                 syn1 = Nd4j.rand(DataType.FLOAT, 10, 2),
                 syn1Neg = Nd4j.rand(DataType.FLOAT, 10, 2);

-        InMemoryLookupTable<VocabWord> lookupTable =
-                (InMemoryLookupTable<VocabWord>) new InMemoryLookupTable.Builder<VocabWord>()
-                        .useAdaGrad(false).cache(cache)
-                        .build();
+        InMemoryLookupTable<VocabWord> lookupTable = new InMemoryLookupTable
+                .Builder<VocabWord>()
+                .useAdaGrad(false)
+                .cache(cache)
+                .build();

         lookupTable.setSyn0(syn0);
         lookupTable.setSyn1(syn1);
@@ -267,7 +271,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
         WeightLookupTable<VocabWord> deser = null;
         try {
             WordVectorSerializer.writeLookupTable(lookupTable, file);
-            ByteArrayOutputStream baos = new ByteArrayOutputStream();
             deser = WordVectorSerializer.readLookupTable(file);
         } catch (Exception e) {
             log.error("",e);
@@ -305,7 +308,6 @@ public class WordVectorSerializerTest extends BaseDL4JTest {

         FastText deser = null;
         try {
-            ByteArrayOutputStream baos = new ByteArrayOutputStream();
             deser = WordVectorSerializer.readWordVectors(new File(dir, "some.data"));
         } catch (Exception e) {
             log.error("",e);
@@ -323,4 +325,32 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
         assertEquals(fastText.getInputFile(), deser.getInputFile());
         assertEquals(fastText.getOutputFile(), deser.getOutputFile());
     }

+    @Test
+    public void testIsHeader_withValidHeader () {
+
+        /* Given */
+        AbstractCache<VocabWord> cache = new AbstractCache<>();
+        String line = "48 100";
+
+        /* When */
+        boolean isHeader = WordVectorSerializer.isHeader(line, cache);
+
+        /* Then */
+        assertTrue(isHeader);
+    }
+
+    @Test
+    public void testIsHeader_notHeader () {
+
+        /* Given */
+        AbstractCache<VocabWord> cache = new AbstractCache<>();
+        String line = "your -0.0017603 0.0030831 0.00069072 0.0020581 -0.0050952 -2.2573e-05 -0.001141";
+
+        /* When */
+        boolean isHeader = WordVectorSerializer.isHeader(line, cache);
+
+        /* Then */
+        assertFalse(isHeader);
+    }
 }
@@ -1,9 +1,9 @@
 package org.deeplearning4j.models.fasttext;

 import lombok.extern.slf4j.Slf4j;
+import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
 import org.deeplearning4j.models.word2vec.Word2Vec;
-import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
 import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
 import org.junit.Rule;
@@ -14,13 +14,14 @@ import org.nd4j.common.primitives.Pair;
 import org.nd4j.common.resources.Resources;

 import java.io.File;
+import java.io.FileNotFoundException;
 import java.io.IOException;

+import static org.hamcrest.CoreMatchers.hasItems;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;


 @Slf4j
 public class FastTextTest extends BaseDL4JTest {

@@ -32,7 +33,6 @@ public class FastTextTest extends BaseDL4JTest {
     private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
     private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
-

     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();

@@ -90,7 +90,7 @@ public class FastTextTest extends BaseDL4JTest {
     }

     @Test
-    public void tesLoadCBOWModel() throws IOException {
+    public void tesLoadCBOWModel() {

         FastText fastText = new FastText(cbowModelFile);
         fastText.test(cbowModelFile);
@@ -99,7 +99,7 @@
         assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));

double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 2e-3);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@@ -111,7 +111,7 @@
         assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));

double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
|
||||||
|
|
||||||
String label = fastText.predict(text);
|
String label = fastText.predict(text);
|
||||||
assertEquals("__label__soccer", label);
|
assertEquals("__label__soccer", label);
|
||||||
|
@@ -126,7 +126,7 @@
         assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));

double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
assertArrayEquals(expected, fastText.getWordVector("association"), 2e-3);
|
||||||
|
|
||||||
String label = fastText.predict(text);
|
String label = fastText.predict(text);
|
||||||
fastText.wordsNearest("test",1);
|
fastText.wordsNearest("test",1);
|
||||||
|
@@ -140,10 +140,10 @@

         Pair<String,Float> result = fastText.predictProbability(text);
         assertEquals("__label__soccer", result.getFirst());
-        assertEquals(-0.6930, result.getSecond(), 1e-4);
+        assertEquals(-0.6930, result.getSecond(), 2e-3);

         assertEquals(48, fastText.vocabSize());
-        assertEquals(0.0500, fastText.getLearningRate(), 1e-4);
+        assertEquals(0.0500, fastText.getLearningRate(), 2e-3);
         assertEquals(100, fastText.getDimension());
         assertEquals(5, fastText.getContextWindowSize());
         assertEquals(5, fastText.getEpoch());
@@ -155,7 +155,7 @@
     }

     @Test
-    public void testVocabulary() throws IOException {
+    public void testVocabulary() {
         FastText fastText = new FastText(supModelFile);
         assertEquals(48, fastText.vocab().numWords());
         assertEquals(48, fastText.vocabSize());
@@ -171,78 +171,73 @@
     }

     @Test
-    public void testLoadIterator() {
-        try {
-            SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
-            FastText fastText =
-                    FastText.builder().supervised(true).iterator(iter).build();
-            fastText.loadIterator();
-
-        } catch (IOException e) {
-            log.error("",e);
-        }
+    public void testLoadIterator() throws FileNotFoundException {
+        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
+        FastText
+                .builder()
+                .supervised(true)
+                .iterator(iter)
+                .build()
+                .loadIterator();
     }

     @Test(expected=IllegalStateException.class)
     public void testState() {
         FastText fastText = new FastText();
-        String label = fastText.predict("something");
+        fastText.predict("something");
     }

     @Test
     public void testPretrainedVectors() throws IOException {
         File output = testDir.newFile();

-        FastText fastText =
-                FastText.builder().supervised(true).
-                        inputFile(inputFile.getAbsolutePath()).
-                        pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
-                        outputFile(output.getAbsolutePath()).build();
-
+        FastText fastText = FastText
+                .builder()
+                .supervised(true)
+                .inputFile(inputFile.getAbsolutePath())
+                .pretrainedVectorsFile(supervisedVectors.getAbsolutePath())
+                .outputFile(output.getAbsolutePath())
+                .build();

         log.info("\nTraining supervised model ...\n");
         fastText.fit();
     }

     @Test
     public void testWordsStatistics() throws IOException {

         File output = testDir.newFile();

-        FastText fastText =
-                FastText.builder().supervised(true).
-                        inputFile(inputFile.getAbsolutePath()).
-                        outputFile(output.getAbsolutePath()).build();
-
+        FastText fastText = FastText
+                .builder()
+                .supervised(true)
+                .inputFile(inputFile.getAbsolutePath())
+                .outputFile(output.getAbsolutePath())
+                .build();

         log.info("\nTraining supervised model ...\n");
         fastText.fit();

-        Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
+        File file = new File(output.getAbsolutePath() + ".vec");
+        Word2Vec word2Vec = WordVectorSerializer.readAsCsv(file);

         assertEquals(48, word2Vec.getVocab().numWords());
-        System.out.println(word2Vec.wordsNearest("association", 3));
-        System.out.println(word2Vec.similarity("Football", "teams"));
-        System.out.println(word2Vec.similarity("professional", "minutes"));
-        System.out.println(word2Vec.similarity("java","cpp"));
+        assertEquals("", 0.1667751520872116, word2Vec.similarity("Football", "teams"), 2e-3);
+        assertEquals("", 0.10083991289138794, word2Vec.similarity("professional", "minutes"), 2e-3);
+        assertEquals("", Double.NaN, word2Vec.similarity("java","cpp"), 0.0);
+        assertThat(word2Vec.wordsNearest("association", 3), hasItems("Football", "Soccer", "men's"));
     }


     @Test
-    public void testWordsNativeStatistics() throws IOException {
+    public void testWordsNativeStatistics() {

-        File output = testDir.newFile();

         FastText fastText = new FastText();
         fastText.loadPretrainedVectors(supervisedVectors);

         log.info("\nTraining supervised model ...\n");

         assertEquals(48, fastText.vocab().numWords());
-        String[] result = new String[3];
-        fastText.wordsNearest("association", 3).toArray(result);
-        assertArrayEquals(new String[]{"most","eleven","hours"}, result);
-        assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
-        assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
-        assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
+        assertThat(fastText.wordsNearest("association", 3), hasItems("most","eleven","hours"));
+        assertEquals(0.1657, fastText.similarity("Football", "teams"), 2e-3);
+        assertEquals(0.3661, fastText.similarity("professional", "minutes"), 2e-3);
+        assertEquals(Double.NaN, fastText.similarity("java","cpp"), 0.0);
     }
 }

@@ -47,7 +47,9 @@ import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.util.Collection;
+import java.util.concurrent.Callable;

+import static org.awaitility.Awaitility.await;
 import static org.junit.Assert.assertEquals;

@@ -190,22 +192,26 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
                         .nOut(4).build())
                 .build();

-        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();

         INDArray w0 = net.getParam("0_W");
         assertEquals(w, w0);



         ByteArrayOutputStream baos = new ByteArrayOutputStream();
         ModelSerializer.writeModel(net, baos, true);
         byte[] bytes = baos.toByteArray();

         ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
-        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
+        final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);

         assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
-        assertEquals(net.params(), restored.params());
+        await()
+                .until(new Callable<Boolean>() {
+                    @Override
+                    public Boolean call() {
+                        return net.params().equalsWithEps(restored.params(), 2e-3);
+                    }
+                });
     }
 }

@@ -63,6 +63,9 @@ public class Deconvolution2D extends ConvolutionLayer {
     protected Deconvolution2D(BaseConvBuilder<?> builder) {
         super(builder);
         initializeConstraints(builder);
+        if(builder instanceof Builder){
+            this.cnn2dDataFormat = ((Builder) builder).format;
+        }
     }

     public boolean hasBias() {
@@ -136,7 +139,7 @@ public class Deconvolution2D extends ConvolutionLayer {

         private CNN2DFormat format = CNN2DFormat.NCHW;

-        public Builder format(CNN2DFormat format){
+        public Builder dataFormat(CNN2DFormat format){
            this.format = format;
            return this;
        }

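For context on the rename above: the same builder field is now set through dataFormat(...) instead of format(...). A minimal sketch of a caller, assuming made-up kernel size and channel counts (only the dataFormat call itself comes from this change):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.layers.Deconvolution2D;

    public class DeconvFormatSketch {
        // Illustrative values only; the 2x2 kernel and 32->16 channels are not from the patch.
        public static Deconvolution2D channelsLastDeconv() {
            return new Deconvolution2D.Builder(2, 2)
                    .nIn(32)
                    .nOut(16)
                    .dataFormat(CNN2DFormat.NHWC)   // renamed from format(...) in this change
                    .build();
        }
    }
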
@@ -310,11 +310,21 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";
-            throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName
+
+            String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
                     + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
-                    + " (data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
+                    + " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
                     + "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
-                    + layerId());
+                    + layerId();
+
+            int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
+            if(input.size(dimIfWrongFormat) == inDepth){
+                //User might have passed NCHW data to a NHWC net, or vice versa?
+                s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
+            }
+
+            throw new DL4JInvalidInputException(s);
         }
     }

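The hint appended above points at the network-level setting that has to agree with how the input arrays are actually laid out. A minimal sketch, assuming a 28x28 single-channel input (the dimensions are illustrative, not taken from the patch):

    import org.deeplearning4j.nn.conf.CNN2DFormat;
    import org.deeplearning4j.nn.conf.inputs.InputType;

    public class InputFormatSketch {
        // Channels-last input type: matches [minibatch, 28, 28, 1] arrays.
        // Feeding [minibatch, 1, 28, 28] (NCHW) data into such a net is the
        // mismatch the new NCHW_NHWC_ERROR_MSG hint is meant to flag.
        public static InputType channelsLastInput() {
            return InputType.convolutional(28, 28, 1, CNN2DFormat.NHWC);
        }
    }
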
@@ -190,12 +190,21 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";
-            throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
+
+            String s = "Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
                     + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
-                    + " (data input channels = " + input.size(cDim) + ", "
+                    + " (data format = " + format + ", data input channels = " + input.size(cDim) + ", "
                     + (nchw ? "[minibatch,inputDepth,height,width]" : "[minibatch,height,width,inputDepth]") + "="
                     + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
-                    + layerId());
+                    + layerId();
+
+            int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
+            if(input.size(dimIfWrongFormat) == inDepth){
+                //User might have passed NCHW data to a NHWC net, or vice versa?
+                s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
+            }
+
+            throw new DL4JInvalidInputException(s);
         }
         int kH = (int) weights.size(2);
         int kW = (int) weights.size(3);

@@ -183,13 +183,21 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";
-            throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " +
+
+            String s = "Cannot do forward pass in DepthwiseConvolution2D layer " +
                     "(layer name = " + layerName
                     + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
-                    + " (data input channels = " + input.size(1) + ", "
+                    + " (data format = " + format + ", data input channels = " + input.size(1) + ", "
                     + (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=")
                     + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
-                    + layerId());
+                    + layerId();
+            int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
+            if(input.size(dimIfWrongFormat) == inDepth){
+                //User might have passed NCHW data to a NHWC net, or vice versa?
+                s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
+            }
+
+            throw new DL4JInvalidInputException(s);
         }
         int kH = (int) depthWiseWeights.size(0);
         int kW = (int) depthWiseWeights.size(1);

@@ -211,11 +211,20 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";
-            throw new DL4JInvalidInputException("Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
+
+            String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
                     + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
-                    + " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
+                    + " (data format = " + format + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
                     + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
-                    + layerId());
+                    + layerId();
+
+            int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
+            if(input.size(dimIfWrongFormat) == inDepth){
+                //User might have passed NCHW data to a NHWC net, or vice versa?
+                s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
+            }
+
+            throw new DL4JInvalidInputException(s);
         }
         int kH = (int) depthWiseWeights.size(2);
         int kW = (int) depthWiseWeights.size(3);

@@ -20,7 +20,7 @@ import lombok.*;
 import lombok.extern.slf4j.Slf4j;
 import org.deeplearning4j.nn.api.Model;
 import org.deeplearning4j.optimize.api.BaseTrainingListener;
-import org.deeplearning4j.util.ThreadUtils;
+import org.nd4j.common.util.ThreadUtils;
 import org.nd4j.linalg.api.ndarray.INDArray;

 import java.io.Serializable;

@@ -27,8 +27,8 @@ import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostPro
 import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
 import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
 import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
-import org.deeplearning4j.util.ThreadUtils;
 import org.nd4j.common.base.Preconditions;
+import org.nd4j.common.util.ThreadUtils;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
 import org.nd4j.linalg.api.memory.enums.*;
@@ -69,7 +69,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
     protected ThreadLocal<Integer> index = new ThreadLocal<>();
     protected long initialMemory = 100 * 1024 * 1024L;
     protected int queueSize = 5;
-    protected Double boundary = 1.0;
+    protected Integer boundary = Integer.MAX_VALUE;
     protected boolean encodingDebugMode;

     protected IndexedTail externalSource;
@@ -101,11 +101,11 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
     }

     public EncodedGradientsAccumulator(int parties, ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, boolean encodingDebugMode) {
-        this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode);
+        this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, Integer.MAX_VALUE, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, Integer.MAX_VALUE, encodingDebugMode);
     }

     public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
-                                       int queueSize, Double boundary, boolean encodingDebugMode) {
+                                       int queueSize, Integer boundary, boolean encodingDebugMode) {
         this.parties = parties;
         this.handler = handler;
         this.initialMemory = initialMemory;
@@ -551,7 +551,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
        protected long initialMemory = DEFAULT_INITIAL_MEMORY;
        protected int queueSize = 5;
        protected MessageHandler handler;
-       protected Double boundary = null;
+       protected int boundary = Integer.MAX_VALUE;
        protected boolean encodingDebugMode;

        /**
@@ -598,15 +598,12 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
        /**
         * This method enables optional limit for max number of updates per message
         *
-        * Default value: 1.0 (no limit)
+        * Default value: Integer.MAX_VALUE (no limit)
         * @param boundary positive value in range 0..1
         * @return
         */
-       public Builder updatesBoundary(double boundary) {
-           if (boundary >= 1.0)
-               return this;
-
-           if (boundary <= 0.0)
+       public Builder updatesBoundary(int boundary) {
+           if (boundary <= 0)
                throw new DL4JInvalidConfigException("Boundary should have positive value");

            this.boundary = boundary;

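To make the reworked option concrete: the boundary is now an absolute cap on encoded updates per message rather than a 0..1 fraction of the update vector. A minimal sketch of how the builder might be called, assuming the Builder's parties constructor and thresholdAlgorithm setter (neither is shown in this diff, so treat both as assumptions):

    import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
    import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;

    public class BoundarySketch {
        public static EncodedGradientsAccumulator buildAccumulator() {
            // Hypothetical configuration: at most 100_000 encoded elements per message,
            // replacing the old fractional boundary (e.g. 0.1 of the update vector).
            return new EncodedGradientsAccumulator.Builder(4)
                    .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
                    .updatesBoundary(100_000)
                    .build();
        }
    }
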
@@ -16,6 +16,7 @@

 package org.deeplearning4j.optimize.solvers.accumulation;

+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;

@@ -24,7 +25,6 @@ import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgori
 import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
 import org.nd4j.linalg.api.buffer.DataBuffer;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.compression.NDArrayCompressor;
 import org.nd4j.linalg.exception.ND4JIllegalStateException;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
@@ -54,9 +54,8 @@ public class EncodingHandler implements MessageHandler {
    protected ThresholdAlgorithm initialThresholdAlgorithm;
    protected ResidualPostProcessor initialResidualPostProcessor;

-   protected Double boundary;
+   protected Integer boundary;
    protected boolean encodingDebugMode;
-   protected NDArrayCompressor compressor;
    protected AtomicInteger atomicBoundary = new AtomicInteger(-1);

    protected ThreadLocal<ThresholdAlgorithm> thresholdAlgorithm = new ThreadLocal<>();
@@ -73,20 +72,16 @@ public class EncodingHandler implements MessageHandler {
    protected final AtomicLong lastThresholdLogTime = new AtomicLong();

    public EncodingHandler(final ThresholdAlgorithm thresholdAlgorithm, final ResidualPostProcessor residualPostProcessor,
-                          Double boundary, boolean encodingDebugMode){
+                          Integer boundary, boolean encodingDebugMode){
        this.initialThresholdAlgorithm = thresholdAlgorithm;
        this.initialResidualPostProcessor = residualPostProcessor;
-       this.boundary = boundary;
+       this.boundary = boundary == null ? Integer.MAX_VALUE : boundary;
        this.encodingDebugMode = encodingDebugMode;
    }

    @Override
    public void initialize(@NonNull GradientsAccumulator accumulator) {
        this.accumulator = accumulator;
-
-       compressor = Nd4j.getCompressor().getCompressor("THRESHOLD");
-       if (compressor == null)
-           throw new ND4JIllegalStateException("Can't find Threshold compressor implementation!");
    }

    public INDArray encodeUpdates(int iteration, int epoch, INDArray updates) {
@@ -135,14 +130,13 @@ public class EncodingHandler implements MessageHandler {
        iterations.get().incrementAndGet();

        if (boundary != null && atomicBoundary.get() < 0)
-           atomicBoundary.compareAndSet(-1, (int) (updates.length() * boundary));
+           atomicBoundary.compareAndSet(-1, (int) (updates.length() / 16) );

        INDArray encoded;

        if (!bitmapMode.get().get()) {
            //Sparse updates
-           encoded = Nd4j.getExecutioner().thresholdEncode(updates, currentThreshold.get().get(),
-                   boundary == null ? null : atomicBoundary.get());
+           encoded = Nd4j.getExecutioner().thresholdEncode(updates, currentThreshold.get().get(), boundary == null ? null : atomicBoundary.get());

            // updates were TOO sparse, nothing to share here
            if (encoded == null) {
@@ -157,17 +151,14 @@ public class EncodingHandler implements MessageHandler {
            }


-           double encLen = encoded.data().getInt(0);
+           double encLen = encoded.length();

            // if updates are too dense - we fallback to bitmap encoding
            if (encLen >= (updates.length() / 16)) {
                log.debug("Switching back to bitmapEncoding: iteration {}, epoch {}, threshold {}, encoded length {}", iteration, epoch, currThreshold, encLen);
                bitmapMode.get().set(true);

-               DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5);
-               encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
-
-               Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());
+               encoded = Nd4j.getExecutioner().bitmapEncode(updates, currentThreshold.get().get());

                applyPostProcessor(iteration, epoch, currThreshold, updates);
                lastSparsityRatio.set(null);

@@ -186,8 +177,7 @@ public class EncodingHandler implements MessageHandler {
            }
        } else {
            //Dense bitmap updates
-           DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(updates.length() / 16 + 5);
-           encoded = Nd4j.createArrayFromShapeBuffer(buffer, updates.shapeInfoDataBuffer());
+           encoded = Nd4j.create(DataType.INT32, updates.length() / 16 + 5);

            long values = Nd4j.getExecutioner().bitmapEncode(updates, encoded, currentThreshold.get().get());

@ -18,6 +18,7 @@ package org.deeplearning4j.optimize.solvers.accumulation;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.common.util.ThreadUtils;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
@ -28,8 +29,6 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
|
|
||||||
import org.deeplearning4j.util.ThreadUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This BlockingQueue implementation is suited only for symmetric gradients updates, and should NOT be used anywhere else.
|
* This BlockingQueue implementation is suited only for symmetric gradients updates, and should NOT be used anywhere else.
|
||||||
*
|
*
|
||||||
|
|
|
@ -48,6 +48,13 @@ import java.util.Arrays;
|
||||||
*/
|
*/
|
||||||
public class ConvolutionUtils {
|
public class ConvolutionUtils {
|
||||||
|
|
||||||
|
public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first)" +
|
||||||
|
" or NHWC (channels last) format for input images and activations.\n" +
|
||||||
|
"Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using" +
|
||||||
|
" .setInputType(InputType.convolutional(height, width, depth, CNN2DForman.NCHW/NHWC)).\n" +
|
||||||
|
"ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";
|
||||||
|
|
||||||
|
|
||||||
private static final int[] ONES = new int[]{1, 1};
|
private static final int[] ONES = new int[]{1, 1};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -910,7 +910,7 @@ public class ParallelWrapper implements AutoCloseable {
|
||||||
Preconditions.checkState(thresholdAlgorithm != null, "Cannot use SHARED_GRADIENTS training mode without setting a threshold algorithm");
|
Preconditions.checkState(thresholdAlgorithm != null, "Cannot use SHARED_GRADIENTS training mode without setting a threshold algorithm");
|
||||||
this.trainerContext = new SymmetricTrainerContext();
|
this.trainerContext = new SymmetricTrainerContext();
|
||||||
if (this.accumulator == null) {
|
if (this.accumulator == null) {
|
||||||
log.info("Creating new GradientsAccumulator instance with threshold of [5e-4");
|
log.info("Creating new GradientsAccumulator instance with default threshold of [5e-4]");
|
||||||
this.accumulator = new EncodedGradientsAccumulator(workers, thresholdAlgorithm, residualPostProcessor, false);
|
this.accumulator = new EncodedGradientsAccumulator(workers, thresholdAlgorithm, residualPostProcessor, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ public class WiredEncodingHandler extends EncodingHandler {
|
||||||
* @param thresholdAlgorithm threshold algorithm to use
|
* @param thresholdAlgorithm threshold algorithm to use
|
||||||
* @param boundary
|
* @param boundary
|
||||||
*/
|
*/
|
||||||
public WiredEncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Double boundary, boolean encodingDebugMode) {
|
public WiredEncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Integer boundary, boolean encodingDebugMode) {
|
||||||
super(thresholdAlgorithm, residualPostProcessor, boundary, encodingDebugMode);
|
super(thresholdAlgorithm, residualPostProcessor, boundary, encodingDebugMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class WiredEncodingHandler extends EncodingHandler {
|
||||||
*
|
*
|
||||||
* @param thresholdAlgorithm The threshold algorithm to use
|
* @param thresholdAlgorithm The threshold algorithm to use
|
||||||
*/
|
*/
|
||||||
public WiredEncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Double boundary, boolean encodingDebugMode) {
|
public WiredEncodingHandler(ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, Integer boundary, boolean encodingDebugMode) {
|
||||||
super(thresholdAlgorithm, residualPostProcessor, boundary, encodingDebugMode);
|
super(thresholdAlgorithm, residualPostProcessor, boundary, encodingDebugMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,62 +0,0 @@
# DL4J auto-generated documentation

## Building

Run `./gen_all_docs.sh` to generate documentation from source for all supported projects. For each documentation module, files will be put into a `doc_sources` folder where they are staged for copying to the primary docs repository. Note that the autogen docs require Python 2.

To deploy a new version of documentation, first make sure to set `$DL4J_DOCS_DIR` to your local copy of
https://github.com/eclipse/deeplearning4j-docs and set `$DL4J_VERSION` to a URI-friendly version string such as `v100-RC` (note the lack of decimals). Then run `./copy-to-dl4j-docs.sh`. This puts documentation
into the right folders and you can use `git` to create a PR and update the live docs.
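For example, a full build-and-stage run might look like the following sketch. The local path and version string below are placeholders; adjust them to your own checkout:

```shell
# Illustrative values only
export DL4J_DOCS_DIR=$HOME/code/deeplearning4j-docs
export DL4J_VERSION=v100-RC

./gen_all_docs.sh        # generate doc_sources for every module (requires Python 2)
./copy-to-dl4j-docs.sh   # copy staged docs into $DL4J_DOCS_DIR/docs/_$DL4J_VERSION
```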
The structure of this project (template files, generating code, mkdocs YAML) is closely aligned
with the [Keras documentation](https://keras.io) and heavily inspired by the [Keras docs repository](https://github.com/keras-team/keras/tree/master/docs).

## File structure

Each major module or library in Eclipse Deeplearning4j has its own folder. Inside that folder are three essential items:

- `templates/`
- `pages.json`
- `README.md`

Note that the folder names don't exactly match up with the modules in the `pom.xml` definitions across DL4J. This is because some of the documentation is consolidated (such as DataVec) or omitted because it is experimental or low-level in the code.

Templates must maintain a flat file structure. This is to accommodate Jekyll collections when the docs are published. Don't worry about having similarly named files in different doc modules - the module name is prepended when the docs are generated.

## Creating templates

Each template has a Jekyll header at the top:

```markdown
---
title: Deeplearning4j Autoencoders
short_title: Autoencoders
description: Supported autoencoder configurations.
category: Models
weight: 3
---
```

All of these definitions are necessary.

- `title` is the HTML title that appears for a Google result or at the top of the browser window.
- `short_title` is a short name for simple navigation in the user guide.
- `description` is the text that appears below the title in a search engine result.
- `category` is the high-level category in the user guide.
- `weight` is the ordering in which the doc appears in navigation: the larger the value, the lower the doc is listed.

## Creating links

**All links to other docs need to be relative.** This prolongs the life of the documentation and reduces maintenance. The basic structure of a link to another doc looks like:

```
<module name>-<file name>
```

So if you created a DataVec doc with the name `iterators.md` in the `datavec` module, your relative link will look like:

```
./datavec-iterators
```

Note the omission of the file extension `.md`. Jekyll automatically generates a clean URL for us to use.
@ -1,16 +0,0 @@
################################################################################
# Copyright (c) 2015-2019 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
@ -1,10 +0,0 @@
# arbiter documentation

To generate docs into the `arbiter/doc_sources` folder, first `cd docs`, then run:

```shell
python generate_docs.py \
    --project arbiter \
    --code ../arbiter \
    --out_language en
```
@ -1,61 +0,0 @@
|
||||||
{
|
|
||||||
"excludes": [
|
|
||||||
"abstract"
|
|
||||||
],
|
|
||||||
"indices": [
|
|
||||||
],
|
|
||||||
"pages": [
|
|
||||||
{
|
|
||||||
"page": "overview.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "visualization.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "parameter-spaces.md",
|
|
||||||
"class": [
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/discrete/DiscreteParameterSpace.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/BooleanSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/AlphaDropoutSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianDropoutSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianNoiseSpace.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/MathOp.java",
|
|
||||||
"arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/PairMathOp.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "layer-spaces.md",
|
|
||||||
"class": [
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ActivationLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AutoEncoderLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BatchNormalizationSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Bidirectional.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/CenterLossOutputLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ConvolutionLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Deconvolution2DLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DenseLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DropoutLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/EmbeddingLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/FeedForwardLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GlobalPoolingLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesBidirectionalLSTMLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesLSTMLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LSTMLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LocalResponseNormalizationLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LossLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OCNNLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OutputLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/RnnOutputLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SeparableConvolution2DLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SubsamplingLayerSpace.java",
|
|
||||||
"arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/VariationalAutoencoderLayerSpace.java"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,11 +0,0 @@
---
title: Arbiter Layer Spaces
short_title: Layer Spaces
description: Set search spaces for layers.
category: Arbiter
weight: 1
---

## Layer Spaces

{{autogenerated}}
@ -1,257 +0,0 @@
---
title: Arbiter Overview
short_title: Overview
description: Introduction to using Arbiter for hyperparameter optimization.
category: Arbiter
weight: 0
---

## Hyperparameter Optimization

Machine learning techniques have a set of parameters that have to be chosen before any training can begin. These parameters are referred to as hyperparameters. Some examples of hyperparameters are 'k' in k-nearest-neighbors and the regularization parameter in Support Vector Machines. Neural networks, in particular, have a wide variety of hyperparameters. Some of these define the architecture of the neural network, like the number of layers and their size. Others define the learning process, like the learning rate and regularization.

Traditionally these choices are made based on existing rules of thumb or after extensive trial and error, both of which are less than ideal. Undoubtedly the choice of these parameters can have a significant impact on the results obtained after learning. Hyperparameter optimization attempts to automate this process using software that applies search strategies.

## Arbiter

Arbiter is part of the DL4J suite of machine learning/deep learning tools for the enterprise. It is dedicated to the hyperparameter optimization of neural networks created or imported into DL4J. It allows users to set up search spaces for the hyperparameters and run either grid search or random search to select the best configuration based on a given scoring metric.

When to use Arbiter?
Arbiter can be used to find good performing models, potentially saving you time tuning your model's hyperparameters, at the expense of greater computational time. Note however that Arbiter doesn't completely automate the neural network tuning process; the user still needs to specify a search space. This search space defines the range of valid values for each hyperparameter (example: minimum and maximum allowable learning rate). If this search space is chosen poorly, Arbiter may not be able to find any good models.

Add the following to your pom.xml to include Arbiter in your project, where the version is the latest release of the DL4J stack.

```xml
<!-- Arbiter - used for hyperparameter optimization (grid/random search) -->
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>arbiter-deeplearning4j</artifactId>
    <version>{{page.version}}</version>
</dependency>
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>arbiter-ui_2.11</artifactId>
    <version>{{page.version}}</version>
</dependency>
```

Arbiter also comes with a handy UI that helps visualize the results from the optimization runs.

As a prerequisite to using Arbiter, users should be familiar with the NeuralNetConfiguration, MultiLayerConfiguration and ComputationGraphConfiguration classes in DL4J.

## Usage

This section provides an overview of the important constructs necessary to use Arbiter. The sections that follow dive into the details.

At the highest level, setting up hyperparameter optimization involves setting up an OptimizationConfiguration and running it via an IOptimizationRunner.

Below is some code that demonstrates the fluent builder pattern in OptimizationConfiguration:

```java
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
    .candidateGenerator(candidateGenerator)
    .dataSource(dataSourceClass, dataSourceProperties)
    .modelSaver(modelSaver)
    .scoreFunction(scoreFunction)
    .terminationConditions(terminationConditions)
    .build();
```

As indicated above, setting up an optimization configuration requires:

- CandidateGenerator: Proposes candidates (i.e., hyperparameter configurations) for evaluation. Candidates are generated based on some strategy; currently random search and grid search are supported. Valid configurations for the candidates are determined by the hyperparameter space associated with the candidate generator.
- DataSource: Used under the hood to provide data to the generated candidates for training and testing.
- ModelSaver: Specifies how the results of each hyperparameter optimization run should be saved, for example whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory.
- ScoreFunction: A single-number metric that we seek to minimize or maximize to determine the best candidate, e.g. model loss or classification accuracy.
- TerminationCondition: Determines when hyperparameter optimization should be stopped, e.g. when a given number of candidates have been evaluated or a certain amount of computation time has passed.

The optimization configuration is then passed to an optimization runner along with a task creator.

If the candidates generated are MultiLayerNetworks, this is set up as follows:

```java
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
```

Alternatively, if the candidates generated are ComputationGraphs, this is set up as follows:

```java
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator());
```

Currently the only option available for the runner is the LocalOptimizationRunner, which is used to execute learning on a single machine (i.e., in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.

To summarize, here are the steps to set up a hyperparameter optimization run (a compact end-to-end sketch follows the list):

1. Specify a hyperparameter search space
1. Specify a candidate generator for the hyperparameter search space
1. The next section of steps can be done in any order:
    1. Specify a data source
    1. Specify a model saver
    1. Specify a score function
    1. Specify a termination condition
1. The next steps have to be done in order:
    1. Use 2 to 6 above to construct an OptimizationConfiguration
    1. Run with the optimization runner.
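As a rough, end-to-end sketch of how these pieces fit together (assembled from the sections below; `mls` is a MultiLayerSpace, `MyDataSource` is a placeholder for your own DataSource implementation, the property name and output directory are illustrative, and imports are omitted as elsewhere in this guide):

```java
// Candidate generator over the search space
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);

// Data source class plus optional properties passed to its configure(...) method
Class<? extends DataSource> dataSourceClass = MyDataSource.class;
Properties dataSourceProperties = new Properties();
dataSourceProperties.setProperty("minibatchSize", "64");

// Where to save candidates, how to score them, and when to stop
ResultSaver modelSaver = new FileModelSaver("arbiterOutput/");
ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);
TerminationCondition[] terminationConditions = {
        new MaxTimeCondition(15, TimeUnit.MINUTES),
        new MaxCandidatesCondition(10)};

OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
        .candidateGenerator(candidateGenerator)
        .dataSource(dataSourceClass, dataSourceProperties)
        .modelSaver(modelSaver)
        .scoreFunction(scoreFunction)
        .terminationConditions(terminationConditions)
        .build();

IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
runner.execute();
```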
## Hyperparameter search space

Arbiter's `ParameterSpace<T>` class defines the acceptable range of values a given hyperparameter may take. A ParameterSpace can be simple, like a ParameterSpace that defines a continuous range of double values (say, for the learning rate), or complicated, with multiple nested parameter spaces within, as is the case for a MultiLayerSpace (which defines a search space for a MultiLayerConfiguration).

## MultiLayerSpace and ComputationGraphSpace

MultiLayerSpace and ComputationGraphSpace are Arbiter's counterparts to DL4J's MultiLayerConfiguration and ComputationGraphConfiguration. They are used to set up parameter spaces for the valid hyperparameters in MultiLayerConfiguration and ComputationGraphConfiguration.

In addition to these, users can also set up the number of epochs or an early stopping configuration to indicate when training on each candidate neural net should stop. If both an EarlyStoppingConfiguration and the number of epochs are specified, early stopping will be used in preference.

Setting up a MultiLayerSpace or ComputationGraphSpace is fairly straightforward once the user is familiar with Integer, Continuous and Discrete parameter spaces and with LayerSpaces and UpdaterSpaces.

The only caveat to be noted here is that while it is possible to set up weightConstraints, l1Bias and l2Bias as part of the NeuralNetConfiguration, these have to be set up on a per layer/layerSpace basis in MultiLayerSpace. In general, all properties/hyperparameters available through the builder will take either a fixed value or a parameter space of that type. This means that pretty much every aspect of the MultiLayerConfiguration can be swept to test out a variety of architectures and initial values.

Here is a simple example of a MultiLayerSpace:

```java
ParameterSpace<Boolean> biasSpace = new DiscreteParameterSpace<>(new Boolean[]{true, false});
ParameterSpace<Integer> firstLayerSize = new IntegerParameterSpace(10, 30);
ParameterSpace<Integer> secondLayerSize = new MathOp<>(firstLayerSize, Op.MUL, 3);
ParameterSpace<Double> firstLayerLR = new ContinuousParameterSpace(0.01, 0.1);
ParameterSpace<Double> secondLayerLR = new MathOp<>(firstLayerLR, Op.ADD, 0.2);

MultiLayerSpace mls =
    new MultiLayerSpace.Builder().seed(12345)
        .hasBias(biasSpace)
        .layer(new DenseLayerSpace.Builder().nOut(firstLayerSize)
            .updater(new AdamSpace(firstLayerLR))
            .build())
        .layer(new OutputLayerSpace.Builder().nOut(secondLayerSize)
            .updater(new AdamSpace(secondLayerLR))
            .build())
        .setInputType(InputType.feedForward(10))
        .numEpochs(20).build(); //Data will be fit for a fixed number of epochs
```

Of particular note is Arbiter's ability to vary the number of layers in the MultiLayerSpace. Here is a simple example demonstrating this, which also shows setting up a parameter search space for a weighted loss function:

```java
ILossFunction[] weightedLossFns = new ILossFunction[]{
    new LossMCXENT(Nd4j.create(new double[]{1, 0.1})),
    new LossMCXENT(Nd4j.create(new double[]{1, 0.05})),
    new LossMCXENT(Nd4j.create(new double[]{1, 0.01}))};

DiscreteParameterSpace<ILossFunction> weightLossFn = new DiscreteParameterSpace<>(weightedLossFns);
MultiLayerSpace mls =
    new MultiLayerSpace.Builder().seed(12345)
        .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build(),
            new IntegerParameterSpace(2, 5)) //2 to 5 identical layers
        .addLayer(new OutputLayerSpace.Builder()
            .iLossFunction(weightLossFn)
            .nIn(10).nOut(2).build())
        .backprop(true).pretrain(false).build();
```

The two to five layers created above will be identical (stacked). Currently Arbiter does not support the ability to create independent layers.

Finally, it is also possible to create a fixed number of identical layers, as shown in the following example:

```java
DiscreteParameterSpace<Activation> activationSpace = new DiscreteParameterSpace(new Activation[]{Activation.IDENTITY, Activation.ELU, Activation.RELU});
MultiLayerSpace mls = new MultiLayerSpace.Builder().updater(new Sgd(0.005))
    .addLayer(new DenseLayerSpace.Builder().activation(activationSpace).nIn(10).nOut(10).build(),
        new FixedValue<Integer>(3))
    .addLayer(new OutputLayerSpace.Builder().iLossFunction(new LossMCXENT()).nIn(10).nOut(2).build())
    .backprop(true).build();
```

In this example, with a grid search three separate architectures will be created. They will be identical in every way except for the chosen activation function in the non-output layers. Again it is to be noted that the layers created in each architecture are identical (stacked).

Creating a ComputationGraphSpace is very similar to MultiLayerSpace. However, there is currently only support for fixed graph structures.

Here is a simple example demonstrating setting up a ComputationGraphSpace:

```java
ComputationGraphSpace cgs = new ComputationGraphSpace.Builder()
    .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1)))
    .l2(new ContinuousParameterSpace(0.2, 0.5))
    .addInputs("in")
    .addLayer("0", new DenseLayerSpace.Builder().nIn(10).nOut(10).activation(
        new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)).build(), "in")
    .addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(10)
        .activation(Activation.SOFTMAX).build(), "0")
    .setOutputs("1").setInputTypes(InputType.feedForward(10)).build();
```

### JSON serialization

MultiLayerSpace, ComputationGraphSpace and OptimizationConfiguration have `toJson` methods as well as `fromJson` methods. You can store the JSON representation for further use.

## Specifying a candidate generator

As mentioned earlier, Arbiter currently supports grid search and random search.

Setting up a random search is straightforward and is shown below:

    MultiLayerSpace mls;
    ...
    CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);

Setting up a grid search is also simple. With a grid search the user also gets to specify a discretization count and a mode. The discretization count determines how many values a continuous parameter is binned into; for example, a continuous parameter in the range [0,1] is converted to [0.0, 0.5, 1.0] with a discretizationCount of 3. The mode determines the manner in which the candidates are generated: Sequential (in order) or RandomOrder. With sequential order the first hyperparameter will be changed most rapidly and consequently the last hyperparameter will be changed the least rapidly. Note that both modes will result in the same set of candidates, just in varying order.

Here is a simple example of how a grid search is set up with a discretization count of 4 in sequential order:

```java
CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(mls, 4,
    GridSearchCandidateGenerator.Mode.Sequential);
```

## Specifying a data source

The DataSource interface defines where the data for training the different candidates comes from. It is very straightforward to implement. Note that a no-argument constructor is required to be defined. Depending on the needs of the user, the DataSource implementation can be configured with properties, like the size of the minibatch. A simple implementation of the data source that uses the MNIST dataset is available in the example repo, which is covered later in this guide; a minimal sketch of what an implementation might look like is shown below.

It is important to note here that the number of epochs (as well as early stopping configurations) can be set via the MultiLayerSpace and ComputationGraphSpace builders.
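The following is a minimal, illustrative sketch of a DataSource implementation backed by the MNIST iterator. The class name and the `minibatchSize` property are arbitrary choices, and the interface method names are written from memory of the Arbiter DataSource interface, so verify them against the version you are using:

```java
public class ExampleMnistDataSource implements DataSource {
    private int minibatchSize = 64;

    public ExampleMnistDataSource() { }   //A no-argument constructor is required

    @Override
    public void configure(Properties properties) {
        //Properties are passed via OptimizationConfiguration.Builder.dataSource(class, properties)
        this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "64"));
    }

    @Override
    public Object trainData() {
        try {
            return new MnistDataSetIterator(minibatchSize, true, 12345);   //training split
        } catch (Exception e) { throw new RuntimeException(e); }
    }

    @Override
    public Object testData() {
        try {
            return new MnistDataSetIterator(minibatchSize, false, 12345);  //test split
        } catch (Exception e) { throw new RuntimeException(e); }
    }

    @Override
    public Class<?> getDataType() {
        return DataSetIterator.class;     //Tells Arbiter what trainData()/testData() return
    }
}
```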
## Specifying a model/result saver

Arbiter currently supports saving models either to local disk (FileModelSaver) or storing the results in memory (InMemoryResultSaver). InMemoryResultSaver is obviously not recommended for large models.

Setting them up is trivial. The FileModelSaver constructor takes a path as a String and saves the config, parameters and score to baseDir/0/, baseDir/1/, etc., where the index is given by OptimizationResult.getIndex(). InMemoryResultSaver requires no arguments.
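As a quick illustration (the output directory below is an arbitrary placeholder), both savers can be constructed as follows:

```java
// Save each candidate's model, config and score under arbiterOutput/<candidate index>/
ResultSaver diskSaver = new FileModelSaver("arbiterOutput/");

// Or keep everything in memory (only suitable for small models and experiments)
ResultSaver memorySaver = new InMemoryResultSaver();
```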
## Specifying a score function

There are three main classes for score functions: EvaluationScoreFunction, ROCScoreFunction and RegressionScoreFunction.

EvaluationScoreFunction uses a DL4J evaluation metric. Available metrics are ACCURACY, F1, PRECISION, RECALL, GMEASURE and MCC. Here is a simple example that uses accuracy:

    ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);

ROCScoreFunction calculates AUC (area under ROC curve) or AUPRC (area under precision/recall curve) on the test set. Different ROC types (ROC, ROCBinary and ROCMultiClass) are supported. Here is a simple example that uses AUC:

    ScoreFunction sf = new ROCScoreFunction(ROCScoreFunction.ROCType.BINARY, ROCScoreFunction.Metric.AUC);

RegressionScoreFunction is used for regression and supports all DL4J RegressionEvaluation metrics (MSE, MAE, RMSE, RSE, PC, R2). Here is a simple example:

    ScoreFunction sf = new RegressionScoreFunction(RegressionEvaluation.Metric.MSE);

## Specifying a termination condition

Arbiter currently only supports two kinds of termination conditions: MaxTimeCondition and MaxCandidatesCondition. MaxTimeCondition specifies a time after which hyperparameter optimization will be terminated. MaxCandidatesCondition specifies a maximum number of candidates after which hyperparameter optimization is terminated. Termination conditions can be specified as a list; hyperparameter optimization stops if any of the conditions are met.

Here is a simple example where the run is terminated after fifteen minutes or after training ten candidates, whichever is met first:

```java
TerminationCondition[] terminationConditions = {
    new MaxTimeCondition(15, TimeUnit.MINUTES),
    new MaxCandidatesCondition(10)
};
```

## Example Arbiter Run on MNIST data

The DL4J example repo contains a BasicHyperparameterOptimizationExample on MNIST data. Users can walk through this simple example here. This example also goes through setting up the Arbiter UI. Arbiter uses the same storage and persistence approach as DL4J's UI. More documentation on the UI can be found here. The UI can be accessed at http://localhost:9000/arbiter.
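For reference, wiring a runner into the Arbiter UI usually looks roughly like the sketch below. This is written from memory of the example; the listener and storage class names come from the arbiter-ui and deeplearning4j-ui modules and should be verified against your version, and the stats file name is arbitrary:

```java
// Persist optimization progress and attach it to the DL4J UI server (port 9000 by default)
StatsStorage ss = new FileStatsStorage(new File("arbiterUiStats.dl4j"));
runner.addListeners(new ArbiterStatusListener(ss));
UIServer.getInstance().attach(ss);
```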
## Tips for hyperparameter tuning

Please refer to the excellent section on hyperparameter optimization here from the CS231N class at Stanford. A summary of these techniques is below:

- Prefer random search over grid search. For a comparison of random and grid search methods, see Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012).
- Run the search from coarse to fine (start with a coarse parameter search with one or two epochs, pick the best candidate to do a fine search on with more epochs, iterate).
- Use LogUniformDistribution for certain hyperparameters like the learning rate, l2, etc. (see the sketch after this list).
- Be mindful of values that fall close to the borders of the parameter search space.
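As a hedged illustration of the LogUniformDistribution tip above (this assumes Arbiter's LogUniformDistribution class and the ContinuousParameterSpace constructor that accepts a distribution; check both against your version):

```java
// Sample the learning rate log-uniformly between 1e-4 and 1e-1 instead of uniformly
ParameterSpace<Double> learningRate =
        new ContinuousParameterSpace(new LogUniformDistribution(1e-4, 1e-1));
```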
@ -1,11 +0,0 @@
---
title: Arbiter Parameter Spaces
short_title: Parameter Spaces
description: Set search spaces for parameters.
category: Arbiter
weight: 1
---

## Parameter Spaces

{{autogenerated}}
@ -1,34 +0,0 @@
#!/usr/bin/env bash

################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################

# Make sure to set $DL4J_DOCS_DIR to your local copy of https://github.com/deeplearning4j/deeplearning4j-docs
SOURCE_DIR=$(pwd)

# print the current git status
cd $DL4J_DOCS_DIR
git status

cd $SOURCE_DIR

# each release is its own jekyll collection located in docs/<version>
DOCS_DEST=$DL4J_DOCS_DIR/docs/_$DL4J_VERSION
mkdir $DOCS_DEST
echo Copying to $DOCS_DEST

# recursively find all files in doc_sources and copy
find $SOURCE_DIR/*/doc_sources -maxdepth 1 -type f -exec cp '{}' $DOCS_DEST \;
@ -1,10 +0,0 @@
# datavec documentation

To generate docs into the `datavec/doc_sources` folder, first `cd docs`, then run:

```shell
python generate_docs.py \
    --project datavec \
    --code ../datavec \
    --out_language en
```
@ -1,203 +0,0 @@
|
||||||
{
|
|
||||||
"excludes": [
|
|
||||||
"abstract"
|
|
||||||
],
|
|
||||||
"indices": [
|
|
||||||
],
|
|
||||||
"pages": [
|
|
||||||
{
|
|
||||||
"page": "overview.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "normalization.md",
|
|
||||||
"module": [
|
|
||||||
"/../nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "records.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/impl/Record.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/impl/SequenceRecord.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "readers.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java",
|
|
||||||
"datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java",
|
|
||||||
"datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java",
|
|
||||||
"datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRegexRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/LibSvmRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "executors.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java",
|
|
||||||
"datavec-spark/src/main/java/org/datavec/spark/transform/SparkTransformExecutor.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "schema.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/schema/SequenceSchema.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/schema/InferredSchema.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/join/Join.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "transforms.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/StringToCategoricalTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/AddConstantColumnTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/RenameColumnsTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/column/ReorderColumnsTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleColumnsMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathFunctionTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerColumnsMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceEmptyIntegerWithValueTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceInvalidWithIntegerTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongColumnsMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongMathOpTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceMovingWindowReduceTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceOffsetTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/AppendStringColumnTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/ChangeCaseStringTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConcatenateStringColumns.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/MapAllStringsExceptListTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/RemoveWhiteSpaceTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceEmptyStringTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceStringTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringMapTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/time/StringToTimeTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/time/TimeMathOpTransform.java",
|
|
||||||
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalCopyValueTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransform.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransformWithDefault.java",
|
|
||||||
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/ConvertToDouble.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ConvertToInteger.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConvertToString.java",
|
|
||||||
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/Log2Normalizer.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/MinMaxNormalizer.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/StandardizeNormalizer.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/SubtractMeanNormalizer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "operations.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableCheckingOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableMultiOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/ByteWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/DoubleWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/FloatWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/IntWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/LongWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ops/StringWritableOp.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/rank/CalculateSortedRank.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "conditions.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/BooleanColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/CategoricalColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/DoubleColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/InfiniteColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/InvalidValueColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/NaNColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/NullWritableColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/StringColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/column/TrivialColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/sequence/SequenceLengthCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/string/StringRegexColumnCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/BooleanCondition.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/condition/SequenceConditionMode.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "filters.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/filter/ConditionFilter.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/filter/InvalidNumColumns.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "reductions.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "serialization.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/serde/JsonSerializer.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/serde/YamlSerializer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "visualization.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlAnalysis.java",
|
|
||||||
"datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlSequencePlotting.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "analysis.md",
|
|
||||||
"class": [
|
|
||||||
"datavec-spark/src/main/java/org/datavec/spark/transform/AnalyzeSpark.java",
|
|
||||||
"datavec-local/src/main/java/org/datavec/local/transforms/AnalyzeLocal.java"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,58 +0,0 @@
---
title: DataVec Analysis
short_title: Analysis
description: Gather statistics on datasets.
category: DataVec
weight: 2
---

## Analysis of data

Sometimes datasets are too large or too abstract in their format to manually analyze and estimate statistics on certain columns or patterns. DataVec comes with some helper utilities for performing data analysis and computing maximums, means, minimums, and other useful metrics.

## Using Spark for analysis

If you have loaded your data into Apache Spark, DataVec has a special `AnalyzeSpark` class which can generate histograms, collect statistics, and return information about the quality of the data. Assuming you have already loaded your data into a Spark RDD, pass the `JavaRDD` and `Schema` to the class.

If you are using DataVec in Scala and your data was loaded into a regular `RDD` class, you can convert it by calling `.toJavaRDD()` which returns a `JavaRDD`. If you need to convert it back, call `rdd()`.

The code below demonstrates some of the many analyses available for a 2D dataset in Spark, using the RDD `javaRdd` and the schema `mySchema`:

```java
import org.datavec.spark.transform.AnalyzeSpark;
import org.datavec.api.writable.Writable;
import org.datavec.api.transform.analysis.*;

int maxHistogramBuckets = 10;
DataAnalysis analysis = AnalyzeSpark.analyze(mySchema, javaRdd, maxHistogramBuckets);

DataQualityAnalysis qualityAnalysis = AnalyzeSpark.analyzeQuality(mySchema, javaRdd);

Writable max = AnalyzeSpark.max(javaRdd, "myColumn", mySchema);

int numSamples = 5;
List<Writable> sample = AnalyzeSpark.sampleFromColumn(numSamples, "myColumn", mySchema, javaRdd);
```

Note that if you have sequence data, there are special methods for that as well:

```java
SequenceDataAnalysis seqAnalysis = AnalyzeSpark.analyzeSequence(mySchema, sequenceRdd);

List<Writable> uniqueSequence = AnalyzeSpark.getUniqueSequence("myColumn", seqSchema, sequenceRdd);
```

## Analyzing locally

The `AnalyzeLocal` class works very similarly to its Spark counterpart and has a similar API. Instead of passing an RDD, it accepts a `RecordReader` which allows it to iterate over the dataset.

```java
import org.datavec.local.transforms.AnalyzeLocal;

int maxHistogramBuckets = 10;
DataAnalysis analysis = AnalyzeLocal.analyze(mySchema, csvRecordReader, maxHistogramBuckets);
```

## Utilities

{{autogenerated}}
@ -1,11 +0,0 @@
---
title: DataVec Conditions
short_title: Conditions
description: Rules for triggering operations and transformations.
category: DataVec
weight: 3
---

## Available conditions

{{autogenerated}}
@ -1,43 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Executors
|
|
||||||
short_title: Executors
|
|
||||||
description: Execute ETL and vectorization in a local instance.
|
|
||||||
category: DataVec
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Local or remote execution?
|
|
||||||
|
|
||||||
Because datasets are commonly large by nature, you can decide on an execution mechanism that best suits your needs. For example, if you are vectorizing a large training dataset, you can process it in a distributed Spark cluster. However, if you need to do real-time inference, DataVec also provides a local executor that doesn't require any additional setup.
|
|
||||||
|
|
||||||
## Executing a transform process
|
|
||||||
|
|
||||||
Once you've created your `TransformProcess` using your `Schema`, and you've either loaded your dataset into a Apache Spark `JavaRDD` or have a `RecordReader` that load your dataset, you can execute a transform.
|
|
||||||
|
|
||||||
Locally this looks like:
|
|
||||||
|
|
||||||
```java
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
|
||||||
|
|
||||||
List<List<Writable>> transformed = LocalTransformExecutor.execute(recordReader, transformProcess)
|
|
||||||
|
|
||||||
List<List<List<Writable>>> transformedSeq = LocalTransformExecutor.executeToSequence(sequenceReader, transformProcess)
|
|
||||||
|
|
||||||
List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, leftReader, rightReader)
|
|
||||||
```
|
|
||||||
|
|
||||||
When using Spark this looks like:
|
|
||||||
|
|
||||||
```java
|
|
||||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
|
||||||
|
|
||||||
JavaRDD<List<Writable>> transformed = SparkTransformExecutor.execute(inputRdd, transformProcess);
|
|
||||||
|
|
||||||
JavaRDD<List<List<Writable>>> transformedSeq = SparkTransformExecutor.executeToSequence(inputSequenceRdd, transformProcess);
|
|
||||||
|
|
||||||
JavaRDD<List<Writable>> joined = SparkTransformExecutor.executeJoin(join, leftRdd, rightRdd);
|
|
||||||
```
|
|
||||||
|
|
||||||
## Available executors
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,23 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Filters
|
|
||||||
short_title: Filters
|
|
||||||
description: Selection of data using conditions.
|
|
||||||
category: DataVec
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Using filters
|
|
||||||
|
|
||||||
Filters are a part of transforms and give you a DSL for keeping only the parts of your dataset that you want. Filters can be one-liners for single conditions or include complex boolean logic.
|
|
||||||
|
|
||||||
```java
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputDataSchema)
|
|
||||||
.filter(new ConditionFilter(new CategoricalColumnCondition("MerchantCountryCode", ConditionOp.NotInSet, new HashSet<>(Arrays.asList("USA","CAN")))))
|
|
||||||
.build();
|
|
||||||
```
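Conditions can also be combined for more complex boolean logic. Below is a minimal sketch that assumes DataVec's `BooleanCondition` helper (with static `AND`/`OR`/`NOT` combinators) and reuses the condition classes from the example above:

```java
Condition notNorthAmerica = new CategoricalColumnCondition("MerchantCountryCode",
        ConditionOp.NotInSet, new HashSet<>(Arrays.asList("USA", "CAN")));
Condition negativeAmount = new DoubleColumnCondition("TransactionAmountUSD", ConditionOp.LessThan, 0.0);

// Remove a record if it is outside North America OR has a negative transaction amount
TransformProcess tp = new TransformProcess.Builder(inputDataSchema)
        .filter(new ConditionFilter(BooleanCondition.OR(notNorthAmerica, negativeAmount)))
        .build();
```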
|
|
||||||
|
|
||||||
You can also write your own filters by implementing the `Filter` interface, though more often than not you will want to create a custom condition instead.
|
|
||||||
|
|
||||||
## Available filters
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,15 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Normalization
|
|
||||||
short_title: Normalization
|
|
||||||
description: Preparing data in the right shape and range for learning.
|
|
||||||
category: DataVec
|
|
||||||
weight: 5
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why normalize?
|
|
||||||
|
|
||||||
Neural networks work best when the data they're fed is normalized, constrained to a range between -1 and 1. There are several reasons for that. One is that nets are trained using gradient descent, and their activation functions usually have an active range somewhere between -1 and 1. Even when using an activation function that doesn't saturate quickly, it is still good practice to constrain your values to this range to improve performance.
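In DL4J this is typically done with an ND4J `DataNormalization` preprocessor attached to a `DataSetIterator`. A minimal sketch, assuming `trainIter` is whatever iterator feeds your training data:

```java
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;

DataSetIterator trainIter = /* your training data iterator */;

// Scale every feature into [0, 1]; use new NormalizerMinMaxScaler(-1, 1) for a symmetric range
DataNormalization scaler = new NormalizerMinMaxScaler(0, 1);
scaler.fit(trainIter);             // collect per-column min/max statistics from the data
trainIter.setPreProcessor(scaler); // normalization is then applied on the fly during iteration
```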
|
|
||||||
|
|
||||||
## Available preprocessors
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,33 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Operations
|
|
||||||
short_title: Operations
|
|
||||||
description: Implementations for advanced transformation.
|
|
||||||
category: DataVec
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Operations, such as a `Function`, help execute transforms and load data into DataVec. The concept of operations is low-level, meaning that most of the time you will not need to worry about them.
|
|
||||||
|
|
||||||
## Loading data into Spark
|
|
||||||
|
|
||||||
If you're using Apache Spark, functions iterate over the dataset, load it into a Spark `RDD`, and convert the raw data into `Writable` values.
|
|
||||||
|
|
||||||
```java
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|
||||||
import org.datavec.spark.transform.misc.StringToWritablesFunction;
|
|
||||||
|
|
||||||
SparkConf conf = new SparkConf();
|
|
||||||
JavaSparkContext sc = new JavaSparkContext(conf);
|
|
||||||
|
|
||||||
String customerInfoPath = new ClassPathResource("CustomerInfo.csv").getFile().getPath();
|
|
||||||
CSVRecordReader rr = new CSVRecordReader(); // parses each line of the CSV into a List<Writable>
JavaRDD<List<Writable>> customerInfo = sc.textFile(customerInfoPath).map(new StringToWritablesFunction(rr));
|
|
||||||
```
|
|
||||||
|
|
||||||
The above code loads a CSV file into a `JavaRDD<List<Writable>>`, with one list of `Writable` values per line. Once your RDD is loaded, you can transform it, perform joins and use reducers to wrangle the data any way you want.
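For example, once a `TransformProcess` has been defined (see the transforms and executors pages), applying it to the loaded RDD is a one-liner; here `transformProcess` is assumed to be an already-built `TransformProcess`:

```java
import org.datavec.spark.transform.SparkTransformExecutor;

JavaRDD<List<Writable>> processed = SparkTransformExecutor.execute(customerInfo, transformProcess);
```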
|
|
||||||
|
|
||||||
## Available ops
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,121 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Overview
|
|
||||||
short_title: Overview
|
|
||||||
description: Overview of the vectorization and ETL library for DL4J.
|
|
||||||
category: DataVec
|
|
||||||
weight: 0
|
|
||||||
---
|
|
||||||
|
|
||||||
## DataVec: A Vectorization and ETL Library
|
|
||||||
|
|
||||||
DataVec solves one of the most important obstacles to effective machine or deep learning: getting data into a format that neural nets can understand. Nets understand vectors. Vectorization is the first problem many data scientists will have to solve to start training their algorithms on data. DataVec should be used for 99% of your data transformations; if you are not sure whether this applies to you, please consult the [gitter](https://gitter.im/deeplearning4j/deeplearning4j). DataVec supports most data formats you could want out of the box, but you may also implement your own custom record reader.
|
|
||||||
|
|
||||||
If your data is in CSV (Comma Separated Values) format stored in flat files that must be converted to numeric values and ingested, or your data is a directory structure of labelled images, then DataVec is the tool to help you organize that data for use in DeepLearning4J.
|
|
||||||
|
|
||||||
|
|
||||||
Please **read this entire page**, particularly the section [Reading Records](#record) below, before working with DataVec.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Introductory Video
|
|
||||||
|
|
||||||
This video describes the conversion of image data to a vector.
|
|
||||||
|
|
||||||
<iframe width="420" height="315" src="https://www.youtube.com/embed/EHHtyRKQIJ0" frameborder="0" allowfullscreen></iframe>
|
|
||||||
|
|
||||||
## Key Aspects
|
|
||||||
- [DataVec](https://github.com/eclipse/deeplearning4j/tree/master/datavec) uses an input/output format system (similar in some ways to how Hadoop MapReduce uses InputFormat to determine InputSplits and RecordReaders; DataVec also provides RecordReaders to serialize data)
|
|
||||||
- Designed to support all major types of input data (text, CSV, audio, image and video) with these specific input formats
|
|
||||||
- Uses an output format system to specify an implementation-neutral type of vector format (SVMLight, etc.)
|
|
||||||
- Can be extended for specialized input formats (such as exotic image formats); i.e. You can write your own custom input format and let the rest of the codebase handle the transformation pipeline
|
|
||||||
- Makes vectorization a first-class citizen
|
|
||||||
- Built in Transformation tools to convert and normalize data
|
|
||||||
- Please see the [DataVec Javadoc](/api/{{page.version}}/) here
|
|
||||||
|
|
||||||
There's a <a href="#tutorial">brief tutorial below</a>.
|
|
||||||
|
|
||||||
## A Few Examples
|
|
||||||
|
|
||||||
* Convert the CSV-based UCI Iris dataset into svmLight open vector text format
|
|
||||||
* Convert the MNIST dataset from raw binary files to the svmLight text format.
|
|
||||||
* Convert raw text into the Metronome vector format
|
|
||||||
* Convert raw text into TF-IDF based vectors in a text vector format {svmLight, metronome}
|
|
||||||
* Convert raw text into the word2vec in a text vector format {svmLight, metronome}
|
|
||||||
|
|
||||||
## Targeted Vectorization Engines
|
|
||||||
|
|
||||||
* Any CSV to vectors with a scriptable transform language
|
|
||||||
* MNIST to vectors
|
|
||||||
* Text to vectors
|
|
||||||
* TF-IDF
|
|
||||||
* Bag of Words
|
|
||||||
* word2vec
|
|
||||||
|
|
||||||
## CSV Transformation Engine
|
|
||||||
|
|
||||||
If data is numeric and appropriately formatted then CSVRecordReader may be satisfactory. If, however, your data has non-numeric fields such as strings representing booleans (T/F) or strings for labels, then a schema transformation will be required. DataVec uses Apache [Spark](http://spark.apache.org/) to perform transform operations. *Note: you do not need to know the internals of Spark to be successful with DataVec transforms.*
|
|
||||||
|
|
||||||
## Schema Transformation Video
|
|
||||||
|
|
||||||
A video tutorial of a simple DataVec transform along with code is available below.
|
|
||||||
<iframe width="560" height="315" src="https://www.youtube.com/embed/MLEMw2NxjxE" frameborder="0" allowfullscreen></iframe>
|
|
||||||
|
|
||||||
## Example Java Code
|
|
||||||
|
|
||||||
Our [examples](https://github.com/eclipse/deeplearning4j-examples) include a collection of DataVec examples.
|
|
||||||
|
|
||||||
<!-- Note to Tom, write DataVec setup content
|
|
||||||
|
|
||||||
## <a name="tutorial">Setting Up DataVec</a>
|
|
||||||
|
|
||||||
Search for [DataVec](https://search.maven.org/#search%7Cga%7C1%7CDataVec) on Maven Central to get a list of JARs you can use.
|
|
||||||
|
|
||||||
Add the dependency information into your pom.xml.
|
|
||||||
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="record">Reading Records, Iterating Over Data</a>
|
|
||||||
|
|
||||||
The following code shows how to work with one example, raw images, transforming them into a format that will work well with DL4J and ND4J:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
// Instantiating RecordReader. Specify height, width and channels of images.
|
|
||||||
// Note that for grayscale output, channels = 1, whereas for RGB images, channels = 3
|
|
||||||
RecordReader recordReader = new ImageRecordReader(28, 28, 3);
|
|
||||||
|
|
||||||
// Point to data path.
|
|
||||||
recordReader.initialize(new FileSplit(new File(labeledPath)));
|
|
||||||
```
|
|
||||||
|
|
||||||
The RecordReader is a class in DataVec that helps convert the byte-oriented input into data that's oriented toward a record; i.e. a collection of elements that are fixed in number and indexed with a unique ID. Converting data to records is the process of vectorization. The record itself is a vector, each element of which is a feature.
|
|
||||||
|
|
||||||
The [ImageRecordReader](https://github.com/eclipse/deeplearning4j/tree/master/datavec/blob/a64389c08396bb39626201beeabb7c4d5f9288f9/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java) is a subclass of the RecordReader and is built to automatically take in 28 x 28 pixel images. Thus, LFW images are scaled to 28 pixels x 28 pixels. You can change dimensions to match your custom images by changing the parameters fed to the ImageRecordReader, as long as you make sure to adjust the `nIn` hyperparameter, which will be equal to the product of image height x image width x number of channels.
|
|
||||||
|
|
||||||
Other `ImageRecordReader` constructor parameters include a boolean flag (`true`) that instructs the reader to append a label to the record, and `labels`, which is the array of supervised values (e.g. targets) used to validate neural net model results. Here are all the RecordReader extensions that come pre-built with DataVec (you can find them by right-clicking on `RecordReader` in IntelliJ, clicking `Go To` in the drop-down menu, and selecting `Implementations`):
|
|
||||||
|
|
||||||
![Alt text](/images/guide/recordreader_extensions.png)
|
|
||||||
|
|
||||||
The DataSetIterator is a Deeplearning4J class that traverses the elements of a list. An iterator passes through the data list, accesses each item sequentially, keeps track of how far it has progressed by pointing to its current element, and modifies itself to point to the next element with each new step in the traversal.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
// DataVec to DL4J
|
|
||||||
DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 784, labels.size());
|
|
||||||
```
|
|
||||||
|
|
||||||
The DataSetIterator iterates through input datasets, fetching one or more new examples with each iteration, and loading those examples into a DataSet object that neural nets can work with. Note that ImageRecordReader produces image data with 4 dimensions that matches DL4J's expected activations layout. Thus, each 28x28 RGB image is represented as a 4d array, with dimensions [minibatch, channels, height, width] = [1, 3, 28, 28]. Note that the constructor line above also specifies the number of labels possible.
|
|
||||||
Note also that ImageRecordReader does not normalize the image data, thus each pixel/channel value will be in the range 0 to 255 (and generally should be normalized separately - for example using ND4J's ImagePreProcessingScaler or another normalizer).
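A minimal sketch of that separate normalization step, scaling pixel values from [0, 255] into [0, 1] with `ImagePreProcessingScaler`:

``` java
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
iter.setPreProcessor(scaler); // each DataSet produced by the iterator is now scaled to [0, 1]
```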
|
|
||||||
|
|
||||||
`RecordReaderDataSetIterator` can take as parameters the specific recordReader you want (for images, sound, etc.) and the batch size. For supervised learning, it will also take a label index and the number of possible labels that can be applied to the input (for LFW, the number of labels is 5,749).
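As a sketch, the usual constructor for a supervised setup takes the reader, the batch size, the index of the label column within each record, and the number of classes; the concrete values below are illustrative only:

``` java
int batchSize = 32;     // examples per minibatch
int labelIndex = 1;     // position of the label Writable within each record
int numClasses = 5749;  // e.g. the number of LFW identities
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
```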
|
|
||||||
|
|
||||||
## Execution
|
|
||||||
|
|
||||||
DataVec runs as both a local serial process and as a MapReduce scale-out process (the MR engine is on the roadmap) with no code changes.
|
|
||||||
|
|
||||||
## Targeted Vector Formats
|
|
||||||
* svmLight
|
|
||||||
* libsvm
|
|
||||||
* Metronome
|
|
||||||
|
|
||||||
## Built-In General Functionality
|
|
||||||
* Understands how to take general text and convert it into vectors with stock techniques such as kernel hashing and TF-IDF
|
|
|
@ -1,32 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Readers
|
|
||||||
short_title: Readers
|
|
||||||
description: Read individual records from different formats.
|
|
||||||
category: DataVec
|
|
||||||
weight: 2
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why readers?
|
|
||||||
|
|
||||||
Readers iterate over the records of a dataset in storage and load the data into DataVec. Their usefulness goes beyond reading individual entries: what if you wanted to train a text generator on a corpus, or programmatically compose two entries together to form a new record? Reader implementations are also useful for complex file types or distributed storage mechanisms.
|
|
||||||
|
|
||||||
Readers return `Writable` classes that describe each column in a `Record`. These classes are used to convert each record to a tensor/ND-Array format.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Each reader implementation extends `BaseRecordReader` and provides a simple API for selecting the next record in a dataset, acting similarly to iterators.
|
|
||||||
|
|
||||||
Useful methods include:
|
|
||||||
|
|
||||||
- `next`: Return a batch of `Writable`.
|
|
||||||
- `nextRecord`: Return a single `Record`, optionally with `RecordMetaData`.
|
|
||||||
- `reset`: Reset the underlying iterator.
|
|
||||||
- `hasNext`: Iterator method to determine if another record is available.
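For example, here is a minimal sketch of iterating a CSV file with the built-in `CSVRecordReader` (the file name and header handling are illustrative):

```java
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;

import java.io.File;
import java.util.List;

CSVRecordReader reader = new CSVRecordReader(1); // skip 1 header line; default comma delimiter
reader.initialize(new FileSplit(new File("iris.csv")));

while (reader.hasNext()) {
    List<Writable> record = reader.next(); // one Writable per column
    System.out.println(record);
}
reader.reset(); // rewind to iterate the data again
```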
|
|
||||||
|
|
||||||
## Listeners
|
|
||||||
|
|
||||||
You can hook a custom `RecordListener` to a record reader for debugging or visualization purposes. Pass your custom listener to the `addListener` base method immediately after initializing your class.
|
|
||||||
|
|
||||||
## Types of readers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,19 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Records
|
|
||||||
short_title: Records
|
|
||||||
description: How to use data records in DataVec.
|
|
||||||
category: DataVec
|
|
||||||
weight: 1
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is a record?
|
|
||||||
|
|
||||||
In the DataVec world a Record represents a single entry in a dataset. DataVec differentiates between types of records to make data manipulation easier with built-in APIs; in particular, standard (2D) records and sequence records are distinguished.
|
|
||||||
|
|
||||||
## Using records
|
|
||||||
|
|
||||||
Most of the time you do not need to interact with the record classes directly, unless you are manually iterating records for the purpose of forwarding through a neural network.
|
|
||||||
|
|
||||||
## Types of records
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,11 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Reductions
|
|
||||||
short_title: Reductions
|
|
||||||
description: Operations for reducing complexity in data.
|
|
||||||
category: DataVec
|
|
||||||
weight: 1
|
|
||||||
---
|
|
||||||
|
|
||||||
## Available reductions
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,60 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Schema
|
|
||||||
short_title: Schema
|
|
||||||
description: Schemas for datasets and transformation.
|
|
||||||
category: DataVec
|
|
||||||
weight: 1
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why use schemas?
|
|
||||||
|
|
||||||
The unfortunate reality is that data is *dirty*. When trying to vectorize a dataset for deep learning, it is quite rare to find files that have zero errors. A schema is important for maintaining the meaning of the data before using it for something like training a neural network.
|
|
||||||
|
|
||||||
## Using schemas
|
|
||||||
|
|
||||||
Schemas are primarily used for programming transformations. Before you can properly execute a `TransformProcess` you will need to pass the schema of the data being transformed.
|
|
||||||
|
|
||||||
An example of a schema for merchant records may look like:
|
|
||||||
|
|
||||||
```java
|
|
||||||
Schema inputDataSchema = new Schema.Builder()
|
|
||||||
.addColumnsString("DateTimeString", "CustomerID", "MerchantID")
|
|
||||||
.addColumnInteger("NumItemsInTransaction")
|
|
||||||
.addColumnCategorical("MerchantCountryCode", Arrays.asList("USA","CAN","FR","MX"))
|
|
||||||
.addColumnDouble("TransactionAmountUSD",0.0,null,false,false) //$0.0 or more, no maximum limit, no NaN and no Infinite values
|
|
||||||
.addColumnCategorical("FraudLabel", Arrays.asList("Fraud","Legit"))
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
## Joining schemas
|
|
||||||
|
|
||||||
If you have two different datasets that you want to merge together, DataVec provides a `Join` class with different join strategies such as `Inner` or `RightOuter`.
|
|
||||||
|
|
||||||
```java
|
|
||||||
Schema customerInfoSchema = new Schema.Builder()
|
|
||||||
.addColumnLong("customerID")
|
|
||||||
.addColumnString("customerName")
|
|
||||||
.addColumnCategorical("customerCountry", Arrays.asList("USA","France","Japan","UK"))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Schema customerPurchasesSchema = new Schema.Builder()
|
|
||||||
.addColumnLong("customerID")
|
|
||||||
.addColumnTime("purchaseTimestamp", DateTimeZone.UTC)
|
|
||||||
.addColumnLong("productID")
|
|
||||||
.addColumnInteger("purchaseQty")
|
|
||||||
.addColumnDouble("unitPriceUSD")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
Join join = new Join.Builder(Join.JoinType.Inner)
|
|
||||||
.setJoinColumns("customerID")
|
|
||||||
.setSchemas(customerInfoSchema, customerPurchasesSchema)
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
Once you've defined your join and you've loaded the data into DataVec, you must use an `Executor` to complete the join.
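Concretely, this reuses the executors described on the executors page. A minimal sketch using the local executor, where `customerInfoData` and `customerPurchasesData` are assumed to be the two datasets already loaded as `List<List<Writable>>`:

```java
import org.datavec.local.transforms.LocalTransformExecutor;

List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, customerInfoData, customerPurchasesData);
```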
|
|
||||||
|
|
||||||
## Classes and utilities
|
|
||||||
|
|
||||||
DataVec comes with a few `Schema` classes and helper utilities for 2D and sequence types of data.
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,32 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Serialization
|
|
||||||
short_title: Serialization
|
|
||||||
description: Data wrangling and mapping from one schema to another.
|
|
||||||
category: DataVec
|
|
||||||
weight: 1
|
|
||||||
---
|
|
||||||
|
|
||||||
## Serializing transforms
|
|
||||||
|
|
||||||
DataVec comes with the ability to serialize transforms, which allows them to be more portable when they're needed for production environments. A `TransformProcess` is serialized to a human-readable format such as JSON and can be saved as a file.
|
|
||||||
|
|
||||||
## Serialization
|
|
||||||
|
|
||||||
The code below shows how you can serialize the transform process `tp`.
|
|
||||||
|
|
||||||
```java
|
|
||||||
String serializedTransformString = tp.toJson();
|
|
||||||
```
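You can then persist that string however you like, for example with plain `java.nio` (the file name is illustrative):

```java
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

Files.write(Paths.get("transform_process.json"), serializedTransformString.getBytes(StandardCharsets.UTF_8));
```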
|
|
||||||
|
|
||||||
## Deserialization
|
|
||||||
|
|
||||||
When you want to reinstantiate the transform process, call the static `from<format>` method.
|
|
||||||
|
|
||||||
```java
|
|
||||||
TransformProcess tp = TransformProcess.fromJson(serializedTransformString);
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Available serializers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,64 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Transforms
|
|
||||||
short_title: Transforms
|
|
||||||
description: Data wrangling and mapping from one schema to another.
|
|
||||||
category: DataVec
|
|
||||||
weight: 1
|
|
||||||
---
|
|
||||||
|
|
||||||
## Data wrangling
|
|
||||||
|
|
||||||
One of the key tools in DataVec is transformations. DataVec helps the user map a dataset from one schema to another, and provides a list of operations to convert types, format data, and convert a 2D dataset to sequence data.
|
|
||||||
|
|
||||||
## Building a transform process
|
|
||||||
|
|
||||||
A transform process requires a `Schema` to successfully transform data. Both the schema and transform process classes come with a helper `Builder` class, which is useful for organizing code and avoiding complex constructors.
|
|
||||||
|
|
||||||
When combined, they look like the sample code below. Note how `inputDataSchema` is passed into the `Builder` constructor. Your transform process will fail to compile without it.
|
|
||||||
|
|
||||||
```java
|
|
||||||
import org.datavec.api.transform.TransformProcess;
|
|
||||||
|
|
||||||
TransformProcess tp = new TransformProcess.Builder(inputDataSchema)
|
|
||||||
.removeColumns("CustomerID","MerchantID")
|
|
||||||
.filter(new ConditionFilter(new CategoricalColumnCondition("MerchantCountryCode", ConditionOp.NotInSet, new HashSet<>(Arrays.asList("USA","CAN")))))
|
|
||||||
.conditionalReplaceValueTransform(
|
|
||||||
"TransactionAmountUSD", //Column to operate on
|
|
||||||
new DoubleWritable(0.0), //New value to use, when the condition is satisfied
|
|
||||||
new DoubleColumnCondition("TransactionAmountUSD",ConditionOp.LessThan, 0.0)) //Condition: amount < 0.0
|
|
||||||
.stringToTimeTransform("DateTimeString","YYYY-MM-DD HH:mm:ss.SSS", DateTimeZone.UTC)
|
|
||||||
.renameColumn("DateTimeString", "DateTime")
|
|
||||||
.transform(new DeriveColumnsFromTimeTransform.Builder("DateTime").addIntegerDerivedColumn("HourOfDay", DateTimeFieldType.hourOfDay()).build())
|
|
||||||
.removeColumns("DateTime")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
## Executing a transformation
|
|
||||||
|
|
||||||
Different "backends" for executors are available. Using the `tp` transform process above, here's how you can execute it locally using plain DataVec.
|
|
||||||
|
|
||||||
```java
|
|
||||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
|
||||||
|
|
||||||
List<List<Writable>> processedData = LocalTransformExecutor.execute(originalData, tp);
|
|
||||||
```
|
|
||||||
|
|
||||||
## Debugging
|
|
||||||
|
|
||||||
Each operation in a transform process represents a "step" in schema changes. Sometimes, the resulting transformation is not the intended result. You can debug this by printing each step in the transform `tp` with the following:
|
|
||||||
|
|
||||||
```java
|
|
||||||
//Now, print the schema after each time step:
|
|
||||||
int numActions = tp.getActionList().size();
|
|
||||||
|
|
||||||
for(int i=0; i<numActions; i++ ){
|
|
||||||
System.out.println("\n\n==================================================");
|
|
||||||
System.out.println("-- Schema after step " + i + " (" + tp.getActionList().get(i) + ") --");
|
|
||||||
|
|
||||||
System.out.println(tp.getSchemaAfterStep(i));
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Available transformations and conversions
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,11 +0,0 @@
|
||||||
---
|
|
||||||
title: DataVec Visualization
|
|
||||||
short_title: Visualization
|
|
||||||
description: UI for visualizing data in DataVec.
|
|
||||||
category: DataVec
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Utilities
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,10 +0,0 @@
|
||||||
# deeplearning4j-nlp documentation
|
|
||||||
|
|
||||||
To generate docs into the `deeplearning4j-nlp/doc_sources` folder, first `cd docs` then run:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python generate_docs.py \
|
|
||||||
--project deeplearning4j-nlp \
|
|
||||||
--code ../deeplearning4j \
|
|
||||||
--out_language en
|
|
||||||
```
|
|
|
@ -1,34 +0,0 @@
|
||||||
{
|
|
||||||
"excludes": [
|
|
||||||
"abstract"
|
|
||||||
],
|
|
||||||
"indices": [
|
|
||||||
],
|
|
||||||
"pages": [
|
|
||||||
{
|
|
||||||
"page": "overview.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "word2vec.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "doc2vec.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "sentence-iterator.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "tokenization.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "vocabulary-cache.md",
|
|
||||||
"class": []
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
---
|
|
||||||
title: Doc2Vec, or Paragraph Vectors, in Deeplearning4j
|
|
||||||
short_title: Doc2Vec
|
|
||||||
description: Doc2Vec and arbitrary documents for language processing in DL4J.
|
|
||||||
category: Language Processing
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Doc2Vec, or Paragraph Vectors, in Deeplearning4j
|
|
||||||
|
|
||||||
The main purpose of Doc2Vec is associating arbitrary documents with labels, so labels are required. Doc2vec is an extension of word2vec that learns to correlate labels and words, rather than words with other words. Deeplearning4j's implementation is intended to serve the Java, [Scala](./scala.html) and Clojure communities.
|
|
||||||
|
|
||||||
The first step is coming up with a vector that represents the "meaning" of a document, which can then be used as input to a supervised machine learning algorithm to associate documents with labels.
|
|
||||||
|
|
||||||
In the ParagraphVectors builder pattern, the `labels()` method points to the labels to train on. In the example below, you can see labels related to sentiment analysis:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
.labels(Arrays.asList("negative", "neutral","positive"))
|
|
||||||
```
|
|
||||||
|
|
||||||
Here's a full working example of [classification with paragraph vectors](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/paragraphvectors/ParagraphVectorsClassifierExample.java):
|
|
||||||
|
|
||||||
``` java
|
|
||||||
public void testDifferentLabels() throws Exception {
|
|
||||||
ClassPathResource resource = new ClassPathResource("/labeled");
|
|
||||||
File file = resource.getFile();
|
|
||||||
LabelAwareSentenceIterator iter = LabelAwareUimaSentenceIterator.createWithPath(file.getAbsolutePath());
|
|
||||||
|
|
||||||
TokenizerFactory t = new UimaTokenizerFactory();
|
|
||||||
|
|
||||||
ParagraphVectors vec = new ParagraphVectors.Builder()
|
|
||||||
.minWordFrequency(1).labels(Arrays.asList("negative", "neutral","positive"))
|
|
||||||
.layerSize(100)
|
|
||||||
.stopWords(new ArrayList<String>())
|
|
||||||
.windowSize(5).iterate(iter).tokenizerFactory(t).build();
|
|
||||||
|
|
||||||
vec.fit();
|
|
||||||
|
|
||||||
assertNotEquals(vec.lookupTable().vector("UNK"), vec.lookupTable().vector("negative"));
|
|
||||||
assertNotEquals(vec.lookupTable().vector("UNK"),vec.lookupTable().vector("positive"));
|
|
||||||
    assertNotEquals(vec.lookupTable().vector("UNK"), vec.lookupTable().vector("neutral"));
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Further Reading
|
|
||||||
|
|
||||||
* [Distributed Representations of Sentences and Documents](https://cs.stanford.edu/~quocle/paragraph_vector.pdf)
|
|
||||||
* [Word2vec: A Tutorial](./word2vec)
|
|
|
@ -1,43 +0,0 @@
|
||||||
---
|
|
||||||
title: Deeplearning4j's NLP Functionality
|
|
||||||
short_title: Overview
|
|
||||||
description: Overview of language processing in DL4J
|
|
||||||
category: Language Processing
|
|
||||||
weight: 0
|
|
||||||
---
|
|
||||||
|
|
||||||
## Deeplearning4j's NLP Functionality
|
|
||||||
|
|
||||||
Although not designed to be comparable to tools such as Stanford CoreNLP or NLTK, Deeplearning4j does include some core text processing tools that are described here.
|
|
||||||
|
|
||||||
Deeplearning4j's NLP relies on [ClearTK](https://cleartk.github.io/cleartk/), an open-source machine learning and natural language processing framework for the Apache [Unstructured Information Management Architecture](https://uima.apache.org/), or UIMA. UIMA enables us to perform language identification, language-specific segmentation, sentence boundary detection and entity detection (proper nouns: persons, corporations, places and things).
|
|
||||||
|
|
||||||
### SentenceIterator
|
|
||||||
|
|
||||||
There are several steps involved in processing natural language. The first is to iterate over your corpus to create a list of documents, which can be as short as a tweet, or as long as a newspaper article. This is performed by a SentenceIterator, which will appear like this:
|
|
||||||
|
|
||||||
<script src="https://gist-it.appspot.com/https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecRawTextExample.java?slice=33:41"></script>
|
|
||||||
|
|
||||||
The SentenceIterator encapsulates a corpus or text, organizing it, say, as one Tweet per line. It is responsible for feeding text piece by piece into your natural language processor. The SentenceIterator is not analogous to a similarly named class, the DatasetIterator, which creates a dataset for training a neural net. Instead it creates a collection of strings by segmenting a corpus.
|
|
||||||
|
|
||||||
### Tokenizer
|
|
||||||
|
|
||||||
A Tokenizer further segments the text at the level of single words, or alternatively into n-grams. ClearTK contains the underlying tokenizers, such as parts of speech (PoS) and parse trees, which allow for both dependency and constituency parsing, like that employed by a recursive neural tensor network (RNTN).
|
|
||||||
|
|
||||||
A Tokenizer is created and wrapped by a [TokenizerFactory](https://github.com/eclipse/deeplearning4j/blob/6f027fd5075e3e76a38123ae5e28c00c17db4361/deeplearning4j-scaleout/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizerfactory/UimaTokenizerFactory.java). The default tokens are words separated by spaces. The tokenization process also involves some machine learning to differentiate between ambiguous symbols like . which both end sentences and abbreviate words such as Mr. and vs.
|
|
||||||
|
|
||||||
Both Tokenizers and SentenceIterators work with Preprocessors to deal with anomalies in messy text like Unicode, and to render such text, say, as lowercase characters uniformly.
|
|
||||||
|
|
||||||
<script src="https://gist-it.appspot.com/https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecRawTextExample.java?slice=43:57"></script>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Vocab
|
|
||||||
|
|
||||||
Each document has to be tokenized to create a vocab, the set of words that matter for that document or corpus. Those words are stored in the vocab cache, which contains statistics about a subset of words counted in the document, the words that "matter". The line separating significant and insignificant words is mobile, but the basic idea of distinguishing between the two groups is that words occurring only once (or less than, say, five times) are hard to learn and their presence represents unhelpful noise.
|
|
||||||
|
|
||||||
The vocab cache stores metadata for methods such as Word2vec and Bag of Words, which treat words in radically different ways. Word2vec creates representations of words, or neural word embeddings, in the form of vectors that are hundreds of coefficients long. Those coefficients help neural nets predict the likelihood of a word appearing in any given context; for example, after another word. Here's Word2vec, configured:
|
|
||||||
|
|
||||||
<script src="https://gist-it.appspot.com/https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecRawTextExample.java"></script>
|
|
||||||
|
|
||||||
Once you obtain word vectors, you can feed them into a deep net for classification, prediction, sentiment analysis and the like.
|
|
|
@ -1,56 +0,0 @@
|
||||||
---
|
|
||||||
title: Sentence Iteration
|
|
||||||
short_title: Sentence Iteration
|
|
||||||
description: Iteration of words, documents, and sentences for language processing in DL4J.
|
|
||||||
category: Language Processing
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Sentence iterator
|
|
||||||
|
|
||||||
A [sentence iterator](./doc/org/deeplearning4j/word2vec/sentenceiterator/SentenceIterator.html) is used in both [Word2vec](./word2vec.html) and [Bag of Words](./bagofwords-tf-idf.html).
|
|
||||||
|
|
||||||
It feeds bits of text into a neural network in the form of vectors, and also covers the concept of documents in text processing.
|
|
||||||
|
|
||||||
In natural-language processing, a document or sentence is typically used to encapsulate a context which an algorithm should learn.
|
|
||||||
|
|
||||||
A few examples include analyzing Tweets and full-blown news articles. The purpose of the [sentence iterator](./doc/org/deeplearning4j/word2vec/sentenceiterator/SentenceIterator.html) is to divide text into processable bits. Note the sentence iterator is input agnostic. So bits of text (a document) can come from a file system, the Twitter API or Hadoop.
|
|
||||||
|
|
||||||
Depending on how input is processed, the output of a sentence iterator will then be passed to a [tokenizer](./org/deeplearning4j/word2vec/tokenizer/Tokenizer.html) for the processing of individual tokens, which are usually words, but could also be ngrams, skipgrams or other units. The tokenizer is created on a per-sentence basis by a [tokenizer factory](./doc/org/deeplearning4j/word2vec/tokenizer/TokenizerFactory.html). The tokenizer factory is what is passed into a text-processing vectorizer.
|
|
||||||
|
|
||||||
Some typical examples are below:
|
|
||||||
|
|
||||||
SentenceIterator iter = new LineSentenceIterator(new File("your file"));
|
|
||||||
|
|
||||||
This assumes that each line in a file is a sentence.
|
|
||||||
|
|
||||||
You can also use a collection of strings, treating each string as a sentence, as follows:
|
|
||||||
|
|
||||||
Collection<String> sentences = ...;
|
|
||||||
SentenceIterator iter = new CollectionSentenceIterator(sentences);
|
|
||||||
|
|
||||||
This will assume that each string is a sentence (document). Remember this could be a list of Tweets or articles -- both are applicable.
|
|
||||||
|
|
||||||
You can iterate over files as follows:
|
|
||||||
|
|
||||||
SentenceIterator iter = new FileSentenceIterator(new File("your dir or file"));
|
|
||||||
|
|
||||||
This will parse the files line by line and return individual sentences from each one.
|
|
||||||
|
|
||||||
For anything complex, we recommend an actual machine-learning level pipeline, represented by the [UimaSentenceIterator](./doc/org/deeplearning4j/text/sentenceiterator/UimaSentenceIterator.html).
|
|
||||||
|
|
||||||
The UimaSentenceIterator is capable of tokenization, part-of-speech tagging and lemmatization, among other things. The UimaSentenceIterator iterates over a set of files and can segment sentences. You can customize its behavior based on the AnalysisEngine passed into it.
|
|
||||||
|
|
||||||
The AnalysisEngine is the [UIMA](http://uima.apache.org/) concept of a text-processing pipeline. DeepLearning4j comes with standard analysis engines for all of these common tasks, allowing you to customize which text is being passed in and how you define sentences. The AnalysisEngines are thread-safe versions of the [opennlp](http://opennlp.apache.org/) pipelines. We also include [cleartk](http://cleartk.googlecode.com/)-based pipelines for handling common tasks.
|
|
||||||
|
|
||||||
For those using UIMA or curious about it, this employs the cleartk type system for tokens, sentences, and other annotations within the type system.
|
|
||||||
|
|
||||||
Here's how to create a UimaSentenceIterator.
|
|
||||||
|
|
||||||
SentenceIterator iter = UimaSentenceIterator.create("path/to/your/text/documents");
|
|
||||||
|
|
||||||
You can also instantiate directly:
|
|
||||||
|
|
||||||
SentenceIterator iter = new UimaSentenceIterator(path,AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(TokenizerAnnotator.getDescription(), SentenceAnnotator.getDescription())));
|
|
||||||
|
|
||||||
For those familiar with UIMA, this uses uimaFIT extensively to create analysis engines. You can also create custom sentence iterators by extending SentenceIterator.
|
|
|
@ -1,31 +0,0 @@
|
||||||
---
|
|
||||||
title: Tokenization
|
|
||||||
short_title: Tokenization
|
|
||||||
description: Breaking text into individual words for language processing in DL4J.
|
|
||||||
category: Language Processing
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is Tokenization?
|
|
||||||
|
|
||||||
Tokenization is the process of breaking text down into individual words. Word windows are also composed of tokens. [Word2Vec](./word2vec.html) can output text windows that comprise training examples for input into neural nets, as seen here.
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
Here's an example of tokenization done with DL4J tools:
|
|
||||||
|
|
||||||
//tokenization with lemmatization, part of speech tagging, sentence segmentation
|
|
||||||
TokenizerFactory tokenizerFactory = new UimaTokenizerFactory();
|
|
||||||
Tokenizer tokenizer = tokenizerFactory.tokenize("mystring");
|
|
||||||
|
|
||||||
//iterate over the tokens
|
|
||||||
while(tokenizer.hasMoreTokens()) {
|
|
||||||
String token = tokenizer.nextToken();
|
|
||||||
}
|
|
||||||
|
|
||||||
//get the whole list of tokens
|
|
||||||
List<String> tokens = tokenizer.getTokens();
|
|
||||||
|
|
||||||
The above snippet creates a tokenizer capable of stemming.
|
|
||||||
|
|
||||||
In Word2Vec, that's the recommended way of creating a vocabulary, because it averts various vocabulary quirks, such as the singular and plural of the same noun being counted as two different words.
|
|
|
@ -1,26 +0,0 @@
|
||||||
---
|
|
||||||
title: Vocabulary Cache
|
|
||||||
short_title: Vocab Cache
|
|
||||||
description: Mechanism for handling general NLP tasks in DL4J.
|
|
||||||
category: Language Processing
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
# How the Vocab Cache Works
|
|
||||||
|
|
||||||
The vocabulary cache, or vocab cache, is a mechanism for handling general-purpose natural-language tasks in Deeplearning4j, including normal TF-IDF, word vectors and certain information-retrieval techniques. The goal of the vocab cache is to be a one-stop shop for text vectorization, encapsulating techniques common to bag of words and word vectors, among others.
|
|
||||||
|
|
||||||
Vocab cache handles storage of tokens, word-count frequencies, inverse-document frequencies and document occurrences via an inverted index. The InMemoryLookupCache is the reference implementation.
|
|
||||||
|
|
||||||
In order to use a vocab cache as you iterate over text and index tokens, you need to figure out if the tokens should be included in the vocab. The criterion is usually if tokens occur with more than a certain pre-configured frequency in the corpus. Below that frequency, an individual token isn't a vocab word, and it remains just a token.
|
|
||||||
|
|
||||||
We track tokens as well. In order to track tokens, do the following:
|
|
||||||
|
|
||||||
addToken(new VocabWord(1.0,"myword"));
|
|
||||||
|
|
||||||
When you want to add a vocab word, do the following:
|
|
||||||
|
|
||||||
addWordToIndex(0, Word2Vec.UNK);
|
|
||||||
putVocabWord(Word2Vec.UNK);
|
|
||||||
|
|
||||||
Adding the word to the index sets the index. Then you declare it as a vocab word. (Declaring it as a vocab word will pull the word from the index.)
|
|
|
@ -1,495 +0,0 @@
|
||||||
---
|
|
||||||
title: Word2Vec in Deeplearning4j
|
|
||||||
short_title: Word2Vec
|
|
||||||
description: Neural word embeddings for NLP in DL4J.
|
|
||||||
category: Language Processing
|
|
||||||
weight: 2
|
|
||||||
---
|
|
||||||
|
|
||||||
## Word2Vec, Doc2vec & GloVe: Neural Word Embeddings for Natural Language Processing
|
|
||||||
|
|
||||||
Contents
|
|
||||||
|
|
||||||
* <a href="#intro">Introduction</a>
|
|
||||||
* <a href="#embed">Neural Word Embeddings</a>
|
|
||||||
* <a href="#crazy">Amusing Word2vec Results</a>
|
|
||||||
* <a href="#just">**Just Give Me the Code**</a>
|
|
||||||
* <a href="#anatomy">Anatomy of Word2Vec</a>
|
|
||||||
* <a href="#setup">Setup, Load and Train</a>
|
|
||||||
* <a href="#code">A Code Example</a>
|
|
||||||
* <a href="#trouble">Troubleshooting & Tuning Word2Vec</a>
|
|
||||||
* <a href="#use">Word2vec Use Cases</a>
|
|
||||||
* <a href="#foreign">Foreign Languages</a>
|
|
||||||
* <a href="#glove">GloVe (Global Vectors) & Doc2Vec</a>
|
|
||||||
|
|
||||||
## <a name="intro">Introduction to Word2Vec</a>
|
|
||||||
|
|
||||||
Word2vec is a two-layer neural net that processes text. Its input is a text corpus and its output is a set of vectors: feature vectors for words in that corpus. While Word2vec is not a [deep neural network](https://skymind.ai/wiki/neural-network), it turns text into a numerical form that deep nets can understand. [Deeplearning4j](./deeplearning4j-quickstart) implements a distributed form of Word2vec for Java and Scala, which works on Spark with GPUs.
|
|
||||||
|
|
||||||
Word2vec's applications extend beyond parsing sentences in the wild. It can be applied just as well to <a href="#sequence">genes, code, likes, playlists, social media graphs and other verbal or symbolic series</a> in which patterns may be discerned.
|
|
||||||
|
|
||||||
Why? Because words are simply discrete states like the other data mentioned above, and we are simply looking for the transitional probabilities between those states: the likelihood that they will co-occur. So gene2vec, like2vec and follower2vec are all possible. With that in mind, the tutorial below will help you understand how to create neural embeddings for any group of discrete and co-occurring states.
|
|
||||||
|
|
||||||
The purpose and usefulness of Word2vec is to group the vectors of similar words together in vectorspace. That is, it detects similarities mathematically. Word2vec creates vectors that are distributed numerical representations of word features, features such as the context of individual words. It does so without human intervention.
|
|
||||||
|
|
||||||
Given enough data, usage and contexts, Word2vec can make highly accurate guesses about a word’s meaning based on past appearances. Those guesses can be used to establish a word's association with other words (e.g. "man" is to "boy" what "woman" is to "girl"), or cluster documents and classify them by topic. Those clusters can form the basis of search, [sentiment analysis](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java) and recommendations in such diverse fields as scientific research, legal discovery, e-commerce and customer relationship management.
|
|
||||||
|
|
||||||
The output of the Word2vec neural net is a vocabulary in which each item has a vector attached to it, which can be fed into a deep-learning net or simply queried to detect relationships between words.
|
|
||||||
|
|
||||||
Measuring [cosine similarity](https://skymind.ai/wiki/glossary#cosine), no similarity is expressed as a 90 degree angle, while total similarity of 1 is a 0 degree angle, complete overlap; i.e. Sweden equals Sweden, while Norway has a cosine similarity of 0.760124 with Sweden, the highest of any other country.
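For reference, cosine similarity is the dot product of two word vectors divided by the product of their norms, `cos(a, b) = (a · b) / (‖a‖ ‖b‖)`; it equals 1 when the vectors point in the same direction and 0 when they are orthogonal (a 90-degree angle).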
|
|
||||||
|
|
||||||
Here's a list of words associated with "Sweden" using Word2vec, in order of proximity:
|
|
||||||
|
|
||||||
![Cosine Distance](/images/guide/sweden_cosine_distance.png)
|
|
||||||
|
|
||||||
The nations of Scandinavia and several wealthy, northern European, Germanic countries are among the top nine.
|
|
||||||
|
|
||||||
## <a name="embed">Neural Word Embeddings</a>
|
|
||||||
|
|
||||||
The vectors we use to represent words are called *neural word embeddings*, and representations are strange. One thing describes another, even though those two things are radically different. As Elvis Costello said: "Writing about music is like dancing about architecture." Word2vec "vectorizes" about words, and by doing so it makes natural language computer-readable -- we can start to perform powerful mathematical operations on words to detect their similarities.
|
|
||||||
|
|
||||||
So a neural word embedding represents a word with numbers. It's a simple, yet unlikely, translation.
|
|
||||||
|
|
||||||
Word2vec is similar to an autoencoder, encoding each word in a vector, but rather than training against the input words through [reconstruction](https://skymind.ai/wiki/variational-autoencoder), word2vec trains words against other words that neighbor them in the input corpus.
|
|
||||||
|
|
||||||
It does so in one of two ways, either using context to predict a target word (a method known as continuous bag of words, or CBOW), or using a word to predict a target context, which is called skip-gram. We use the latter method because it produces more accurate results on large datasets.
|
|
||||||
|
|
||||||
![word2vec diagram](/images/guide/word2vec_diagrams.png)
|
|
||||||
|
|
||||||
When the feature vector assigned to a word cannot be used to accurately predict that word's context, the components of the vector are adjusted. Each word's context in the corpus is the *teacher* sending error signals back to adjust the feature vector. The vectors of words judged similar by their context are nudged closer together by adjusting the numbers in the vector.
|
|
||||||
|
|
||||||
Just as Van Gogh's painting of sunflowers is a two-dimensional mixture of oil on canvas that *represents* vegetable matter in a three-dimensional space in Paris in the late 1880s, so 500 numbers arranged in a vector can represent a word or group of words.
|
|
||||||
|
|
||||||
Those numbers locate each word as a point in 500-dimensional vectorspace. Spaces of more than three dimensions are difficult to visualize. (Geoff Hinton, teaching people to imagine 13-dimensional space, suggests that students first picture 3-dimensional space and then say to themselves: "Thirteen, thirteen, thirteen." :)
|
|
||||||
|
|
||||||
A well trained set of word vectors will place similar words close to each other in that space. The words *oak*, *elm* and *birch* might cluster in one corner, while *war*, *conflict* and *strife* huddle together in another.
|
|
||||||
|
|
||||||
Similar things and ideas are shown to be "close". Their relative meanings have been translated to measurable distances. Qualities become quantities, and algorithms can do their work. But similarity is just the basis of many associations that Word2vec can learn. For example, it can gauge relations between words of one language, and map them to another.
|
|
||||||
|
|
||||||
![word2vec translation](/images/guide/word2vec_translation.png)
|
|
||||||
|
|
||||||
These vectors are the basis of a more comprehensive geometry of words. As shown in the graph, capital cities such as Rome, Paris, Berlin and Beijing cluster near each other, and they will each have similar distances in vectorspace to their countries; i.e. Rome - Italy = Beijing - China. If you only knew that Rome was the capital of Italy, and were wondering about the capital of China, then the equation Rome - Italy + China would return Beijing. No kidding.
|
|
||||||
|
|
||||||
![capitals output](/images/guide/countries_capitals.png)
|
|
||||||
|
|
||||||
## <a name="crazy">Amusing Word2Vec Results</a>
|
|
||||||
|
|
||||||
Let's look at some other associations Word2vec can produce.
|
|
||||||
|
|
||||||
Instead of the pluses, minus and equals signs, we'll give you the results in the notation of logical analogies, where `:` means "is to" and `::` means "as"; e.g. "Rome is to Italy as Beijing is to China" = `Rome:Italy::Beijing:China`. In the last spot, rather than supplying the "answer", we'll give you the list of words that a Word2vec model proposes, when given the first three elements:
|
|
||||||
|
|
||||||
king:queen::man:[woman, Attempted abduction, teenager, girl]
|
|
||||||
//Weird, but you can kind of see it
|
|
||||||
|
|
||||||
China:Taiwan::Russia:[Ukraine, Moscow, Moldova, Armenia]
|
|
||||||
//Two large countries and their small, estranged neighbors
|
|
||||||
|
|
||||||
house:roof::castle:[dome, bell_tower, spire, crenellations, turrets]
|
|
||||||
|
|
||||||
knee:leg::elbow:[forearm, arm, ulna_bone]
|
|
||||||
|
|
||||||
New York Times:Sulzberger::Fox:[Murdoch, Chernin, Bancroft, Ailes]
|
|
||||||
//The Sulzberger-Ochs family owns and runs the NYT.
|
|
||||||
//The Murdoch family owns News Corp., which owns Fox News.
|
|
||||||
//Peter Chernin was News Corp.'s COO for 13 yrs.
|
|
||||||
//Roger Ailes is president of Fox News.
|
|
||||||
//The Bancroft family sold the Wall St. Journal to News Corp.
|
|
||||||
|
|
||||||
love:indifference::fear:[apathy, callousness, timidity, helplessness, inaction]
|
|
||||||
//the poetry of this single array is simply amazing...
|
|
||||||
|
|
||||||
Donald Trump:Republican::Barack Obama:[Democratic, GOP, Democrats, McCain]
|
|
||||||
//It's interesting to note that, just as Obama and McCain were rivals,
|
|
||||||
//so too, Word2vec thinks Trump has a rivalry with the idea Republican.
|
|
||||||
|
|
||||||
monkey:human::dinosaur:[fossil, fossilized, Ice_Age_mammals, fossilization]
|
|
||||||
//Humans are fossilized monkeys? Humans are what's left
|
|
||||||
//over from monkeys? Humans are the species that beat monkeys
|
|
||||||
//just as Ice Age mammals beat dinosaurs? Plausible.
|
|
||||||
|
|
||||||
building:architect::software:[programmer, SecurityCenter, WinPcap]
|
|
||||||
|
|
||||||
This model was trained on the Google News vocab, which you can [import](#import) and play with. Contemplate, for a moment, that the Word2vec algorithm has never been taught a single rule of English syntax. It knows nothing about the world, and is unassociated with any rules-based symbolic logic or knowledge graph. And yet it learns more, in a flexible and automated fashion, than most knowledge graphs will learn after years of human labor. It comes to the Google News documents as a blank slate, and by the end of training, it can compute complex analogies that mean something to humans.
|
|
||||||
|
|
||||||
You can also query a Word2vec model for other associations. Not everything has to be two analogies that mirror each other. ([We explain how below....](#eval))
|
|
||||||
|
|
||||||
* Geopolitics: *Iraq - Violence = Jordan*
|
|
||||||
* Distinction: *Human - Animal = Ethics*
|
|
||||||
* *President - Power = Prime Minister*
|
|
||||||
* *Library - Books = Hall*
|
|
||||||
* Analogy: *Stock Market ≈ Thermometer*
|
|
||||||
|
|
||||||
By building a sense of one word's proximity to other similar words, which do not necessarily contain the same letters, we have moved beyond hard tokens to a smoother and more general sense of meaning.
|
|
||||||
|
|
||||||
# <a name="just">Just Give Me the Code</a>
|
|
||||||
|
|
||||||
## <a name="anatomy">Anatomy of Word2vec in DL4J</a>
|
|
||||||
|
|
||||||
Here are Deeplearning4j's natural-language processing components:
|
|
||||||
|
|
||||||
* **SentenceIterator/DocumentIterator**: Used to iterate over a dataset. A SentenceIterator returns strings and a DocumentIterator works with inputstreams.
|
|
||||||
* **Tokenizer/TokenizerFactory**: Used in tokenizing the text. In NLP terms, a sentence is represented as a series of tokens. A TokenizerFactory creates an instance of a tokenizer for a "sentence."
|
|
||||||
* **VocabCache**: Used for tracking metadata including word counts, document occurrences, the set of tokens (not vocab in this case, but rather tokens that have occurred), vocab (the features included in both [bag of words](./bagofwords-tf-idf.html) as well as the word vector lookup table)
|
|
||||||
* **Inverted Index**: Stores metadata about where words occurred. Can be used for understanding the dataset. A Lucene index is created automatically when the Lucene-backed implementation is used.
|
|
||||||
|
|
||||||
While Word2vec refers to a family of related algorithms, this implementation uses [Negative Sampling](https://skymind.ai/wiki/glossary#skipgram).
|
|
||||||
|
|
||||||
## <a name="setup">Word2Vec Setup</a>
|
|
||||||
|
|
||||||
Create a new project in IntelliJ using Maven. If you don't know how to do that, see our [Quickstart page](./deeplearning4j-quickstart). Then specify these properties and dependencies in the POM.xml file in your project's root directory (You can [check Maven](https://search.maven.org/#search%7Cga%7C1%7Cnd4j) for the most recent versions -- please use those...).
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Loading Data
|
|
||||||
|
|
||||||
Now create and name a new class in Java. After that, you'll take the raw sentences in your .txt file, traverse them with your iterator, and subject them to some sort of preprocessing, such as converting all words to lowercase.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
String filePath = new ClassPathResource("raw_sentences.txt").getFile().getAbsolutePath();
|
|
||||||
|
|
||||||
log.info("Load & Vectorize Sentences....");
|
|
||||||
// Strip white space before and after for each line
|
|
||||||
SentenceIterator iter = new BasicLineIterator(filePath);
|
|
||||||
```
|
|
||||||
|
|
||||||
If you want to load a text file besides the sentences provided in our example, you'd do this:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
log.info("Load data....");
|
|
||||||
SentenceIterator iter = new LineSentenceIterator(new File("/Users/cvn/Desktop/file.txt"));
|
|
||||||
iter.setPreProcessor(new SentencePreProcessor() {
|
|
||||||
@Override
|
|
||||||
public String preProcess(String sentence) {
|
|
||||||
return sentence.toLowerCase();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
That is, get rid of the `ClassPathResource` and feed the absolute path of your `.txt` file into the `LineSentenceIterator`.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
SentenceIterator iter = new LineSentenceIterator(new File("/your/absolute/file/path/here.txt"));
|
|
||||||
```
|
|
||||||
|
|
||||||
In bash, you can find the absolute file path of any directory by typing `pwd` in your command line from within that same directory. To that path, you'll add the file name and *voila*.
|
|
||||||
|
|
||||||
### Tokenizing the Data
|
|
||||||
|
|
||||||
Word2vec needs to be fed words rather than whole sentences, so the next step is to tokenize the data. To tokenize a text is to break it up into its atomic units, creating a new token each time you hit a white space, for example.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
// Split on white spaces in the line to get words
|
|
||||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
|
||||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
|
||||||
```
|
|
||||||
|
|
||||||
That should give you one word per line.
|
|
||||||
|
|
||||||
### Training the Model
|
|
||||||
|
|
||||||
Now that the data is ready, you can configure the Word2vec neural net and feed in the tokens.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
log.info("Building model....");
|
|
||||||
Word2Vec vec = new Word2Vec.Builder()
|
|
||||||
.minWordFrequency(5)
|
|
||||||
.layerSize(100)
|
|
||||||
.seed(42)
|
|
||||||
.windowSize(5)
|
|
||||||
.iterate(iter)
|
|
||||||
.tokenizerFactory(t)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
log.info("Fitting Word2Vec model....");
|
|
||||||
vec.fit();
|
|
||||||
```
|
|
||||||
|
|
||||||
This configuration accepts a number of hyperparameters. A few require some explanation:
|
|
||||||
|
|
||||||
* *batchSize* is the number of words you process at a time.
|
|
||||||
* *minWordFrequency* is the minimum number of times a word must appear in the corpus. Here, if it appears less than 5 times, it is not learned. Words must appear in multiple contexts to learn useful features about them. In very large corpora, it's reasonable to raise the minimum.
|
|
||||||
* *useAdaGrad* - Adagrad creates a different gradient for each feature. Here we are not concerned with that.
|
|
||||||
* *layerSize* specifies the number of features in the word vector. This is equal to the number of dimensions in the featurespace. Words represented by 500 features become points in a 500-dimensional space.
|
|
||||||
* *learningRate* is the step size for each update of the coefficients, as words are repositioned in the feature space.
|
|
||||||
* *minLearningRate* is the floor on the learning rate. Learning rate decays as the number of words you train on decreases. If learning rate shrinks too much, the net's learning is no longer efficient. This keeps the coefficients moving.
|
|
||||||
* *iterate* tells the net what batch of the dataset it's training on.
|
|
||||||
* *tokenizer* feeds it the words from the current batch.
|
|
||||||
* *vec.fit()* tells the configured net to begin training.
|
|
||||||
|
|
||||||
An example for [uptraining your previously trained word vectors is here](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecUptrainingExample.java).
|
|
||||||
|
|
||||||
### <a name="eval">Evaluating the Model, Using Word2vec</a>
|
|
||||||
|
|
||||||
The next step is to evaluate the quality of your feature vectors.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
// Write word vectors
|
|
||||||
WordVectorSerializer.writeWordVectors(vec, "pathToWriteto.txt");
|
|
||||||
|
|
||||||
log.info("Closest Words:");
|
|
||||||
Collection<String> lst = vec.wordsNearest("day", 10);
|
|
||||||
System.out.println(lst);
|
|
||||||
UiServer server = UiServer.getInstance();
|
|
||||||
System.out.println("Started on port " + server.getPort());
|
|
||||||
|
|
||||||
//output: [night, week, year, game, season, during, office, until, -]
|
|
||||||
```
|
|
||||||
|
|
||||||
The line `vec.similarity("word1","word2")` will return the cosine similarity of the two words you enter. The closer it is to 1, the more similar the net perceives those words to be (see the Sweden-Norway example above). For example:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
double cosSim = vec.similarity("day", "night");
|
|
||||||
System.out.println(cosSim);
|
|
||||||
//output: 0.7704452276229858
|
|
||||||
```
|
|
||||||
|
|
||||||
With `vec.wordsNearest("word1", numWordsNearest)`, the words printed to the screen allow you to eyeball whether the net has clustered semantically similar words. You can set the number of nearest words you want with the second parameter of wordsNearest. For example:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
Collection<String> lst3 = vec.wordsNearest("man", 10);
|
|
||||||
System.out.println(lst3);
|
|
||||||
//output: [director, company, program, former, university, family, group, such, general]
|
|
||||||
```
|
|
||||||
|
|
||||||
### Visualizing the Model
|
|
||||||
|
|
||||||
We rely on [TSNE](https://lvdmaaten.github.io/tsne/) to reduce the dimensionality of word feature vectors and project words into a two or three-dimensional space. The full [DL4J/ND4J example for TSNE is here](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/tsne/TSNEStandardExample.java).
|
|
||||||
|
|
||||||
``` java
|
|
||||||
Nd4j.setDataType(DataBuffer.Type.DOUBLE);
|
|
||||||
int iterations = 100; //number of t-SNE iterations; 100 is an arbitrary choice for this snippet
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
|
||||||
log.info("Load & Vectorize data....");
|
|
||||||
File wordFile = new ClassPathResource("words.txt").getFile(); //Open the file
|
|
||||||
//Get the data of all unique word vectors
|
|
||||||
Pair<InMemoryLookupTable,VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
|
||||||
VocabCache cache = vectors.getSecond();
|
|
||||||
INDArray weights = vectors.getFirst().getSyn0(); //separate weights of unique words into their own list
|
|
||||||
|
|
||||||
for(int i = 0; i < cache.numWords(); i++) //separate strings of words into their own list
|
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations).theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
// .usePca(false)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
String outputFile = "target/archive-tmp/tsne-standard-coords.csv";
|
|
||||||
(new File(outputFile)).getParentFile().mkdirs();
|
|
||||||
|
|
||||||
tsne.fit(weights);
|
|
||||||
tsne.saveAsFile(cacheList, outputFile);
|
|
||||||
```
|
|
||||||
|
|
||||||
### Saving, Reloading & Using the Model
|
|
||||||
|
|
||||||
You'll want to save the model. The normal way to save models in Deeplearning4j is via the serialization utils (Java serialization is akin to Python pickling, converting an object into a *series* of bytes).
|
|
||||||
|
|
||||||
``` java
|
|
||||||
log.info("Save vectors....");
|
|
||||||
WordVectorSerializer.writeWord2VecModel(vec, "pathToSaveModel.txt");
|
|
||||||
```
|
|
||||||
|
|
||||||
This will save the vectors to a file called `pathToSaveModel.txt` that will appear in the root of the directory where Word2vec is trained. The output in the file should have one word per line, followed by a series of numbers that together are its vector representation.
|
|
||||||
|
|
||||||
To keep working with the vectors, simply call methods on `vec` like this:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
Collection<String> kingList = vec.wordsNearest(Arrays.asList("king", "woman"), Arrays.asList("queen"), 10);
|
|
||||||
```
|
|
||||||
|
|
||||||
The classic example of Word2vec's arithmetic of words is "king - queen = man - woman" and its logical extension "king - queen + woman = man".
|
|
||||||
|
|
||||||
The example above will output the 10 nearest words to the vector `king - queen + woman`, which should include `man`. The first parameter for wordsNearest has to include the "positive" words `king` and `woman`, which have a + sign associated with them; the second parameter includes the "negative" word `queen`, which is associated with the minus sign (positive and negative here have no emotional connotation); the third is the length of the list of nearest words you would like to see. Remember to add this to the top of the file: `import java.util.Arrays;`.
|
|
||||||
|
|
||||||
Any number of combinations is possible, but they will only return sensible results if the words you query occurred with enough frequency in the corpus. Obviously, the ability to return similar words (or documents) is at the foundation of both search and recommendation engines.
|
|
||||||
|
|
||||||
You can reload the vectors into memory like this:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
Word2Vec word2Vec = WordVectorSerializer.readWord2VecModel("pathToSaveModel.txt");
|
|
||||||
```
|
|
||||||
|
|
||||||
You can then use Word2vec as a lookup table:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
WeightLookupTable weightLookupTable = word2Vec.lookupTable();
|
|
||||||
Iterator<INDArray> vectors = weightLookupTable.vectors();
|
|
||||||
INDArray wordVectorMatrix = word2Vec.getWordVectorMatrix("myword");
|
|
||||||
double[] wordVector = word2Vec.getWordVector("myword");
|
|
||||||
```
|
|
||||||
|
|
||||||
If the word isn't in the vocabulary, Word2vec returns zeros.
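To avoid silently working with an all-zero vector, you can check vocabulary membership first. This is a small sketch using the `hasWord` lookup on the model loaded above; `"myword"` is a placeholder:

``` java
if (word2Vec.hasWord("myword")) {
    double[] wordVector = word2Vec.getWordVector("myword");
    System.out.println("Vector has " + wordVector.length + " dimensions");
} else {
    System.out.println("'myword' is not in the vocabulary");
}
```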
|
|
||||||
|
|
||||||
### <a name="import">Importing Word2vec Models</a>
|
|
||||||
|
|
||||||
The [Google News Corpus model](https://dl4jdata.blob.core.windows.net/resources/wordvectors/GoogleNews-vectors-negative300.bin.gz) we use to test the accuracy of our trained nets is hosted for download. Users whose current hardware takes a long time to train on large corpora can simply download it to explore a Word2vec model without the prelude.
|
|
||||||
|
|
||||||
If you trained with the [C vectors](https://docs.google.com/file/d/0B7XkCwpI5KDYaDBDQm1tZGNDRHc/edit) or Gensim, this line will import the model.
|
|
||||||
|
|
||||||
``` java
|
|
||||||
File gModel = new File("/Developer/Vector Models/GoogleNews-vectors-negative300.bin.gz");
|
|
||||||
Word2Vec vec = WordVectorSerializer.readWord2VecModel(gModel);
|
|
||||||
```
|
|
||||||
|
|
||||||
Remember to add `import java.io.File;` to your imported packages.
|
|
||||||
|
|
||||||
With large models, you may run into trouble with your heap space. The Google model may take as much as 10G of RAM, and the JVM launches with only 256 MB of RAM by default, so you have to adjust your heap space. You can do that either with a `bash_profile` file (see our [Troubleshooting section](./deeplearning4j-troubleshooting-training)), or through IntelliJ itself:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
//Click:
|
|
||||||
IntelliJ Preferences > Compiler > Command Line Options
|
|
||||||
//Then paste:
|
|
||||||
-Xms1024m
|
|
||||||
-Xmx10g
|
|
||||||
-XX:MaxPermSize=2g
|
|
||||||
```
|
|
||||||
|
|
||||||
### <a name="grams">N-grams & Skip-grams</a>
|
|
||||||
|
|
||||||
Words are read into the vector one at a time, *and scanned back and forth within a certain range*. Those ranges are n-grams, and an n-gram is a contiguous sequence of *n* items from a given linguistic sequence; it is the nth version of unigram, bigram, trigram, four-gram or five-gram. A skip-gram simply drops items from the n-gram.
|
|
||||||
|
|
||||||
The skip-gram representation popularized by Mikolov and used in the DL4J implementation has proven to be more accurate than other models, such as continuous bag of words, due to the more generalizable contexts generated.
|
|
||||||
|
|
||||||
This n-gram is then fed into a neural network to learn the significance of a given word vector; i.e. significance is defined as its usefulness as an indicator of certain larger meanings, or labels.
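As a plain-Java illustration (this is not part of the DL4J API, just a way to see what the windows look like), the n-grams of a short sentence can be enumerated like this; a skip-gram additionally drops words from these windows:

``` java
String[] words = "the quick brown fox jumps".split(" ");
int n = 3; //set to 2 for bigrams, 3 for trigrams, etc.
for (int i = 0; i + n <= words.length; i++) {
    //each window of n consecutive words is one n-gram
    System.out.println(String.join(" ", java.util.Arrays.copyOfRange(words, i, i + n)));
}
//prints: "the quick brown", "quick brown fox", "brown fox jumps"
```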
|
|
||||||
|
|
||||||
### <a name="code">A Working Example</a>
|
|
||||||
|
|
||||||
**Please note** : The code below may be outdated. For updated examples, please see our [dl4j-examples repository on Github](https://github.com/eclipse/deeplearning4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp).
|
|
||||||
|
|
||||||
Now that you have a basic idea of how to set up Word2Vec, here's [one example](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecRawTextExample.java) of how it can be used with DL4J's API:
|
|
||||||
|
|
||||||
<script src="https://gist-it.appspot.com/https://github.com/eclipse/deeplearning4j-examples/blob/master/src/main/java/org/deeplearning4j/examples/nlp/word2vec/Word2VecRawTextExample.java?slice=22:64"></script>
|
|
||||||
|
|
||||||
After following the instructions in the [Quickstart](./deeplearning4j-quickstart), you can open this example in IntelliJ and hit run to see it work. If you query the Word2vec model with a word that isn't contained in the training corpus, it will return null.
|
|
||||||
|
|
||||||
### <a name="trouble">Troubleshooting & Tuning Word2Vec</a>
|
|
||||||
|
|
||||||
*Q: I get a lot of stack traces like this*
|
|
||||||
|
|
||||||
``` java
|
|
||||||
java.lang.StackOverflowError: null
|
|
||||||
at java.lang.ref.Reference.<init>(Reference.java:254) ~[na:1.8.0_11]
|
|
||||||
at java.lang.ref.WeakReference.<init>(WeakReference.java:69) ~[na:1.8.0_11]
|
|
||||||
at java.io.ObjectStreamClass$WeakClassKey.<init>(ObjectStreamClass.java:2306) [na:1.8.0_11]
|
|
||||||
at java.io.ObjectStreamClass.lookup(ObjectStreamClass.java:322) ~[na:1.8.0_11]
|
|
||||||
at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1134) ~[na:1.8.0_11]
|
|
||||||
at java.io.ObjectOutputStream.defaultWriteFields(ObjectOutputStream.java:1548) ~[na:1.8.0_11]
|
|
||||||
```
|
|
||||||
|
|
||||||
*A:* Look inside the directory where you started your Word2vec application. This can, for example, be an IntelliJ project home directory or the directory where you launched `java` at the command line. It should have some directories that look like:
|
|
||||||
|
|
||||||
```
|
|
||||||
ehcache_auto_created2810726831714447871diskstore
|
|
||||||
ehcache_auto_created4727787669919058795diskstore
|
|
||||||
ehcache_auto_created3883187579728988119diskstore
|
|
||||||
ehcache_auto_created9101229611634051478diskstore
|
|
||||||
```
|
|
||||||
|
|
||||||
You can shut down your Word2vec application and try to delete them.
|
|
||||||
|
|
||||||
*Q: Not all of the words from my raw text data are appearing in my Word2vec object…*
|
|
||||||
|
|
||||||
*A:* Try to raise the layer size via **.layerSize()** on your Word2Vec object like so
|
|
||||||
|
|
||||||
``` java
|
|
||||||
Word2Vec vec = new Word2Vec.Builder().layerSize(300).windowSize(5)
|
|
||||||
        .iterate(iter).tokenizerFactory(t).build();
|
|
||||||
```
|
|
||||||
|
|
||||||
*Q: How do I load my data? Why does training take forever?*
|
|
||||||
|
|
||||||
*A:* If all of your sentences have been loaded as *one* sentence, Word2vec training could take a very long time. That's because Word2vec is a sentence-level algorithm: sentence boundaries matter, because co-occurrence statistics are gathered sentence by sentence. (For GloVe, sentence boundaries don't matter, because it looks at corpus-wide co-occurrence.) For many corpora, the average sentence length is six words. That means that with a window size of 5 you have, say, 30 (random number here) rounds of skip-gram calculations. If you forget to specify your sentence boundaries, you may load a "sentence" that's 10,000 words long. In that case, Word2vec would attempt a full skip-gram cycle for the whole 10,000-word "sentence". In DL4J's implementation, a line is assumed to be a sentence. You need to plug in your own SentenceIterator and Tokenizer. By asking you to specify how your sentences end, DL4J remains language-agnostic. UimaSentenceIterator is one way to do that: it uses OpenNLP for sentence boundary detection.
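Two ways of plugging in sentence boundaries are sketched below. The first assumes the `deeplearning4j-nlp-uima` module is on your classpath and uses a placeholder directory path; the second splits the raw text yourself. Exact method signatures may differ between DL4J versions:

``` java
// Option 1 (assumes deeplearning4j-nlp-uima): OpenNLP-based sentence boundary detection
SentenceIterator uimaIter = UimaSentenceIterator.createWithPath("path/to/your/text/documents");

// Option 2: split on periods yourself and iterate over the resulting collection
String rawText = "First sentence. Second sentence. Third sentence.";
SentenceIterator simpleIter = new CollectionSentenceIterator(java.util.Arrays.asList(rawText.split("\\.")));
```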
|
|
||||||
|
|
||||||
|
|
||||||
*Q: Why is there such a difference in performance when feeding whole documents as one "sentence" vs splitting into Sentences?*
|
|
||||||
|
|
||||||
*A:* If the average sentence contains 6 words and the window size is 5, the theoretical maximum of 10 skip-gram rounds per word (5 context words on each side) is reached for zero words: no word in such a short sentence has a full window on both sides. Roughly 5 skip-gram rounds per word is the most you can get in a sentence like that.
|
|
||||||
|
|
||||||
But if your "sentence" is 1,000,000 words long, you'll have 10 skip-gram rounds for every word in it, excluding the first five and the last five. So you'll have to spend far more time building the model, and the co-occurrence statistics will be skewed by the absence of sentence boundaries.
|
|
||||||
|
|
||||||
*Q: How does Word2Vec Use Memory?*
|
|
||||||
|
|
||||||
*A:* The major memory consumer in Word2vec is the weights matrix. The math is simple: NumberOfWords x NumberOfDimensions x 2 x DataTypeSize gives the memory footprint.
|
|
||||||
|
|
||||||
So, if you build a Word2vec model for 100k words using floats and 100 dimensions, your memory footprint will be 100k x 100 x 2 x 4 (float size) = 80MB of RAM just for the matrix, plus some space for strings, variables, threads, etc.
|
|
||||||
|
|
||||||
If you load a pre-built model, it uses roughly half as much RAM as during build time, so about 40MB of RAM in this case.
|
|
||||||
|
|
||||||
The most popular model used so far is the Google News model: 3 million words with vector size 300. That comes to 3.6GB just to load the model. On top of that you have to add 3 million strings, which do not have a constant size in Java. So a loaded model usually takes around 4-6GB, depending on JVM version/vendor, GC state and the phase of the moon.
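The same arithmetic can be checked in a few lines of Java; the counts below are the ballpark figures from this answer, not exact measurements:

``` java
long smallModel = 100_000L * 100 * 2 * 4;  //100k words, 100 dims, 2 matrices during training, 4-byte floats -> ~80MB
long googleNews = 3_000_000L * 300 * 4;    //3M words, 300 dims, single matrix when loading, 4-byte floats -> ~3.6GB
System.out.println(smallModel + " bytes and " + googleNews + " bytes");
```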
|
|
||||||
|
|
||||||
|
|
||||||
*Q: I did everything you said and the results still don't look right.*
|
|
||||||
|
|
||||||
*A:* Make sure you're not running into normalization issues. Some tasks, like wordsNearest(), use normalized weights by default, while others require non-normalized weights. Pay attention to this difference.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="use">Use Cases</a>
|
|
||||||
|
|
||||||
Google Scholar keeps a running tally of the papers citing [Deeplearning4j's implementation of Word2vec here](https://scholar.google.com/scholar?hl=en&q=deeplearning4j+word2vec&btnG=&as_sdt=1%2C5&as_sdtp=).
|
|
||||||
|
|
||||||
Kenny Helsens, a data scientist based in Belgium, [applied Deeplearning4j's implementation of Word2vec](http://thinkdata.be/2015/06/10/word2vec-on-raw-omim-database/) to the NCBI's Online Mendelian Inheritance In Man (OMIM) database. He then looked for the words most similar to alk, a known oncogene of non-small cell lung carcinoma, and Word2vec returned: "nonsmall, carcinomas, carcinoma, mapdkd." From there, he established analogies between other cancer phenotypes and their genotypes. This is just one example of the associations Word2vec can learn on a large corpus. The potential for discovering new aspects of important diseases has only just begun, and outside of medicine, the opportunities are equally diverse.
|
|
||||||
|
|
||||||
Andreas Klintberg trained Deeplearning4j's implementation of Word2vec on Swedish, and wrote a [thorough walkthrough on Medium](https://medium.com/@klintcho/training-a-word2vec-model-for-swedish-e14b15be6cb).
|
|
||||||
|
|
||||||
Word2Vec is especially useful in preparing text-based data for information retrieval and QA systems, which DL4J implements with [deep autoencoders](./deeplearning4j-nn-autoencoders).
|
|
||||||
|
|
||||||
Marketers might seek to establish relationships among products to build a recommendation engine. Investigators might analyze a social graph to surface members of a single group, or other relations they might have to location or financial sponsorship.
|
|
||||||
|
|
||||||
### <a name="patent">Google's Word2vec Patent</a>
|
|
||||||
|
|
||||||
Word2vec is [a method of computing vector representations of words](https://arxiv.org/pdf/1301.3781.pdf) introduced by a team of researchers at Google led by Tomas Mikolov. Google [hosts an open-source version of Word2vec](https://code.google.com/p/word2vec/) released under an Apache 2.0 license. In 2014, Mikolov left Google for Facebook, and in May 2015, [Google was granted a patent for the method](http://patft.uspto.gov/netacgi/nph-Parser?Sect1=PTO2&Sect2=HITOFF&p=1&u=%2Fnetahtml%2FPTO%2Fsearch-bool.html&r=1&f=G&l=50&co1=AND&d=PTXT&s1=9037464&OS=9037464&RS=9037464), which does not abrogate the Apache license under which it has been released.
|
|
||||||
|
|
||||||
### <a name="foreign">Foreign Languages</a>
|
|
||||||
|
|
||||||
While words in all languages may be converted into vectors with Word2vec, and those vectors learned with Deeplearning4j, NLP preprocessing can be very language specific, and requires tools beyond our libraries. The [Stanford Natural Language Processing Group](http://nlp.stanford.edu/software/) has a number of Java-based tools for tokenization, part-of-speech tagging and named-entity recognition for languages such as [Mandarin Chinese](http://nlp.stanford.edu/projects/chinese-nlp.shtml), Arabic, French, German and Spanish. For Japanese, NLP tools like [Kuromoji](http://www.atilika.org/) are useful. Other foreign-language resources, including [text corpora, are available here](http://www-nlp.stanford.edu/links/statnlp.html).
|
|
||||||
|
|
||||||
### <a name="glove">GloVe: Global Vectors</a>
|
|
||||||
|
|
||||||
Loading and saving GloVe models to word2vec can be done like so:
|
|
||||||
|
|
||||||
``` java
|
|
||||||
WordVectors wordVectors = WordVectorSerializer.loadTxtVectors(new File("glove.6B.50d.txt"));
|
|
||||||
```
|
|
||||||
|
|
||||||
### <a name="sequence">Sequence Vectors</a>
|
|
||||||
|
|
||||||
Deeplearning4j has a class called [SequenceVectors](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java), which is one level of abstraction above word vectors, and which allows you to extract features from any sequence, including social media profiles, transactions, proteins, etc. If data can be described as a sequence, it can be learned via skip-gram and hierarchic softmax with the SequenceVectors class. This is compatible with the [DeepWalk algorithm](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java), also implemented in Deeplearning4j.
|
|
||||||
|
|
||||||
### <a name="features">Word2Vec Features on Deeplearning4j</a>
|
|
||||||
|
|
||||||
* Weights update after model serialization/deserialization was added. That is, you can update model state with, say, 200GB of new text by calling `loadFullModel`, adding `TokenizerFactory` and `SentenceIterator` to it, and calling `fit()` on the restored model (see the sketch after this list).
|
|
||||||
* Option for multiple datasources for vocab construction was added.
|
|
||||||
* Epochs and Iterations can be specified separately, although they are both typically "1".
|
|
||||||
* Word2Vec.Builder has this option: `hugeModelExpected`. If set to `true`, the vocab will be periodically truncated during the build.
|
|
||||||
* While `minWordFrequency` is useful for ignoring rare words in the corpus, any number of words can be excluded to customize the vocabulary.
|
|
||||||
* Two new WordVectorSerializer methods have been introduced: `writeFullModel` and `loadFullModel`. These save and load a full model state.
|
|
||||||
* A decent workstation should be able to handle a vocab with a few million words. Deeplearning4j's Word2vec implementation can model a few terabytes of data on a single machine. Roughly, the math is: `vectorSize * 4 * 3 * vocab.size()`.
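A minimal uptraining sketch for the first bullet above might look like the following. It assumes `t` is the `TokenizerFactory` and `newIter` is a `SentenceIterator` over the new text; the file name is a placeholder, and exact method signatures may differ between DL4J versions:

``` java
Word2Vec vec = WordVectorSerializer.loadFullModel("pathToSaveModel.txt"); //restore the full model state
vec.setTokenizerFactory(t);        //reattach a tokenizer factory
vec.setSentenceIterator(newIter);  //iterate over the new text corpus
vec.fit();                         //continue training from the restored state
```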
|
|
||||||
|
|
||||||
### Doc2vec & Other NLP Resources
|
|
||||||
|
|
||||||
* [DL4J Example of Text Classification With Word2vec & RNNs](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/recurrent/word2vecsentiment/Word2VecSentimentRNN.java)
|
|
||||||
* [DL4J Example of Text Classification With Paragraph Vectors](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/nlp/paragraphvectors/ParagraphVectorsClassifierExample.java)
|
|
||||||
* [Doc2vec, or Paragraph Vectors, With Deeplearning4j](./deeplearning4j-nlp-doc2vec)
|
|
||||||
* [Thought Vectors, Natural Language Processing & the Future of AI](https://skymind.ai/wiki/thought-vectors)
|
|
||||||
* [Quora: How Does Word2vec Work?](http://www.quora.com/How-does-word2vec-work)
|
|
||||||
* [Quora: What Are Some Interesting Word2Vec Results?](http://www.quora.com/Word2vec/What-are-some-interesting-Word2Vec-results/answer/Omer-Levy)
|
|
||||||
* [Word2Vec: an introduction](http://www.folgertkarsdorp.nl/word2vec-an-introduction/); Folgert Karsdorp
|
|
||||||
* [Mikolov's Original Word2vec Code @Google](https://code.google.com/p/word2vec/)
|
|
||||||
* [word2vec Explained: Deriving Mikolov et al.’s Negative-Sampling Word-Embedding Method](https://arxiv.org/pdf/1402.3722v1.pdf); Yoav Goldberg and Omer Levy
|
|
||||||
* [Advances in Pre-Training Distributed Word Representations - by Mikolov et al](https://arxiv.org/abs/1712.09405)
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="doctorow">Word2Vec in Literature</a>
|
|
||||||
|
|
||||||
It's like numbers are language, like all the letters in the language are turned into numbers, and so it's something that everyone understands the same way. You lose the sounds of the letters and whether they click or pop or touch the palate, or go ooh or aah, and anything that can be misread or con you with its music or the pictures it puts in your mind, all of that is gone, along with the accent, and you have a new understanding entirely, a language of numbers, and everything becomes as clear to everyone as the writing on the wall. So as I say there comes a certain time for the reading of the numbers.
|
|
||||||
-- E.L. Doctorow, Billy Bathgate
|
|
|
@ -1,10 +0,0 @@
|
||||||
# deeplearning4j-nn documentation
|
|
||||||
|
|
||||||
To generate docs into the `deeplearning4j-nn/doc_sources` folder, first `cd docs` then run:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python generate_docs.py \
|
|
||||||
--project deeplearning4j-nn \
|
|
||||||
--code ../deeplearning4j \
|
|
||||||
--out_language en
|
|
||||||
```
|
|
|
@ -1,187 +0,0 @@
|
||||||
{
|
|
||||||
"excludes": [
|
|
||||||
"abstract"
|
|
||||||
],
|
|
||||||
"indices": [
|
|
||||||
],
|
|
||||||
"pages": [
|
|
||||||
{
|
|
||||||
"page": "evaluation.md",
|
|
||||||
"module": [
|
|
||||||
"/deeplearning4j-nn/src/main/java/org/deeplearning4j/eval/"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "model-persistence.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "visualization.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "tsne-visualization.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "transfer-learning.md",
|
|
||||||
"module": [
|
|
||||||
"/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/transferlearning/"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "listeners.md",
|
|
||||||
"module": [
|
|
||||||
"/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/listeners/"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "iterators.md",
|
|
||||||
"module": [
|
|
||||||
"/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/",
|
|
||||||
"/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/",
|
|
||||||
"/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/"
|
|
||||||
],
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/datasets/iterator/impl/MultiDataSetIteratorAdapter.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "layers.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Pooling1D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Pooling2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "autoencoders.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/BernoulliReconstructionDistribution.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ExponentialReconstructionDistribution.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/GaussianReconstructionDistribution.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/LossFunctionWrapper.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/ReconstructionDistribution.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "convolutional.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "recurrent.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "custom-layer.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "vertices.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/InputVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PoolHelperVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "early-stopping.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingModelSaver.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingResult.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ClassificationScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculatorCG.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ROCScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/ScoreImprovementEpochTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/BestScoreEpochTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/EpochTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/InvalidScoreIterationTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/IterationTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/MaxEpochsTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/MaxScoreIterationTerminationCondition.java",
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/earlystopping/termination/MaxTimeIterationTerminationCondition.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "computationgraph.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "multilayernetwork.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
---
|
|
||||||
title: Deeplearning4j Autoencoders
|
|
||||||
short_title: Autoencoders
|
|
||||||
description: Supported autoencoder configurations.
|
|
||||||
category: Models
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## What are autoencoders?
|
|
||||||
|
|
||||||
Autoencoders are neural networks for unsupervised learning. Eclipse Deeplearning4j supports certain autoencoder layers such as variational autoencoders.
|
|
||||||
|
|
||||||
## Where's Restricted Boltzmann Machine?
|
|
||||||
|
|
||||||
RBMs are no longer supported as of version 0.9.x. They are no longer best-in-class for most machine learning problems.
|
|
||||||
|
|
||||||
## Supported layers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,258 +0,0 @@
|
||||||
---
|
|
||||||
title: Complex Architectures with Computation Graph
|
|
||||||
short_title: Computation Graph
|
|
||||||
description: How to build complex networks with DL4J computation graph.
|
|
||||||
category: Models
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Building Complex Network Architectures with Computation Graph
|
|
||||||
|
|
||||||
This page describes how to build more complicated networks, using DL4J's Computation Graph functionality.
|
|
||||||
|
|
||||||
**Contents**
|
|
||||||
|
|
||||||
* [Overview of the Computation Graph](#overview)
|
|
||||||
* [Computation Graph: Some Example Use Cases](#usecases)
|
|
||||||
* [Configuring a ComputationGraph network](#config)
|
|
||||||
* [Types of Graph Vertices](#vertextypes)
|
|
||||||
* [Example 1: Recurrent Network with Skip Connections](#rnnskip)
|
|
||||||
* [Example 2: Multiple Inputs and Merge Vertex](#multiin)
|
|
||||||
* [Example 3: Multi-Task Learning](#multitask)
|
|
||||||
* [Automatically Adding PreProcessors and Calculating nIns](#preprocessors)
|
|
||||||
* [Training Data for ComputationGraph](#data)
|
|
||||||
* [RecordReaderMultiDataSetIterator Example 1: Regression Data](#rrmdsi1)
|
|
||||||
* [RecordReaderMultiDataSetIterator Example 2: Classification and Multi-Task Learning](#rrmdsi2)
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="overview">Overview of Computation Graph</a>
|
|
||||||
|
|
||||||
DL4J has two types of networks comprised of multiple layers:
|
|
||||||
|
|
||||||
- The [MultiLayerNetwork](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java), which is essentially a stack of neural network layers (with a single input layer and single output layer), and
|
|
||||||
- The [ComputationGraph](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java), which allows for greater freedom in network architectures
|
|
||||||
|
|
||||||
|
|
||||||
Specifically, the ComputationGraph allows for networks to be built with the following features:
|
|
||||||
|
|
||||||
- Multiple network input arrays
|
|
||||||
- Multiple network outputs (including mixed classification/regression architectures)
|
|
||||||
- Layers connected to other layers using a directed acyclic graph connection structure (instead of just a stack of layers)
|
|
||||||
|
|
||||||
As a general rule, when building networks with a single input layer, a single output layer, and an input->a->b->c->output type connection structure: MultiLayerNetwork is usually the preferred network. However, everything that MultiLayerNetwork can do, ComputationGraph can do as well - though the configuration may be a little more complicated.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="usecases">Computation Graph: Some Example Use Cases</a>
|
|
||||||
|
|
||||||
Examples of some architectures that can be built using ComputationGraph include:
|
|
||||||
|
|
||||||
- Multi-task learning architectures
|
|
||||||
- Recurrent neural networks with skip connections
|
|
||||||
- [GoogLeNet](https://arxiv.org/abs/1409.4842), a complex type of convolutional neural network for image classification
|
|
||||||
- [Image caption generation](https://arxiv.org/abs/1411.4555)
|
|
||||||
- [Convolutional networks for sentence classification](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/sentenceclassification/CnnSentenceClassificationExample.java)
|
|
||||||
- [Residual learning convolutional neural networks](https://arxiv.org/abs/1512.03385)
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="config">Configuring a Computation Graph</a>
|
|
||||||
|
|
||||||
### <a name="vertextypes">Types of Graph Vertices</a>
|
|
||||||
|
|
||||||
The basic idea is that in the ComputationGraph, the core building block is the [GraphVertex](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java), instead of layers. Layers (or, more accurately, the [LayerVertex](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java) objects) are just one type of vertex in the graph. Other types of vertices include:
|
|
||||||
|
|
||||||
- Input Vertices
|
|
||||||
- Element-wise operation vertices
|
|
||||||
- Merge vertices
|
|
||||||
- Subset vertices
|
|
||||||
- Preprocessor vertices
|
|
||||||
|
|
||||||
These types of graph vertices are described briefly below.
|
|
||||||
|
|
||||||
**LayerVertex**: Layer vertices (graph vertices with neural network layers) are added using the ```.addLayer(String,Layer,String...)``` method. The first argument is the label for the layer, and the last arguments are the inputs to that layer.
|
|
||||||
If you need to manually add an [InputPreProcessor](https://github.com/eclipse/deeplearning4j/tree/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor) (usually this is unnecessary - see next section) you can use the ```.addLayer(String,Layer,InputPreProcessor,String...)``` method.
|
|
||||||
|
|
||||||
**InputVertex**: Input vertices are specified by the ```addInputs(String...)``` method in your configuration. The strings used as inputs can be arbitrary - they are user-defined labels, and can be referenced later in the configuration. The number of strings provided defines the number of inputs; the order of the inputs also defines the order of the corresponding INDArrays in the fit methods (or the DataSet/MultiDataSet objects).
|
|
||||||
|
|
||||||
**ElementWiseVertex**: Element-wise operation vertices perform, for example, an element-wise addition or subtraction of the activations from one or more other vertices. The activations used as input for the ElementWiseVertex must all be the same size, and the output size of the element-wise vertex is the same as the inputs.
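As a sketch, an element-wise sum of two equally sized dense layers could be wired up like this; the layer names and sizes are illustrative, not prescriptive:

```java
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
    .graphBuilder()
    .addInputs("input")
    .addLayer("L1", new DenseLayer.Builder().nIn(5).nOut(4).build(), "input")
    .addLayer("L2", new DenseLayer.Builder().nIn(5).nOut(4).build(), "input")
    .addVertex("add", new ElementWiseVertex(ElementWiseVertex.Op.Add), "L1", "L2") //element-wise sum; output size is also 4
    .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).build(), "add")
    .setOutputs("out")
    .build();
```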
|
|
||||||
|
|
||||||
**MergeVertex**: The MergeVertex concatenates/merges the input activations. For example, if a MergeVertex has 2 inputs of size 5 and 10 respectively, then output size will be 5+10=15 activations. For convolutional network activations, examples are merged along the depth: so suppose the activations from one layer have 4 features and the other has 5 features (both with (4 or 5) x width x height activations), then the output will have (4+5) x width x height activations.
|
|
||||||
|
|
||||||
**SubsetVertex**: The subset vertex allows you to get only part of the activations out of another vertex. For example, to get the first 5 activations out of another vertex with label "layer1", you can use ```.addVertex("subset1", new SubsetVertex(0,4), "layer1")```: this means that the 0th through 4th (inclusive) activations out of the "layer1" vertex will be used as output from the subset vertex.
|
|
||||||
|
|
||||||
**PreProcessorVertex**: Occasionally, you might want to use the functionality of an [InputPreProcessor](https://github.com/eclipse/deeplearning4j/tree/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor) without that preprocessor being associated with a layer. The PreProcessorVertex allows you to do this.
|
|
||||||
|
|
||||||
Finally, it is also possible to define custom graph vertices by implementing both a [configuration](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/GraphVertex.java) and [implementation](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java) class for your custom GraphVertex.
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="rnnskip">Example 1: Recurrent Network with Skip Connections</a>
|
|
||||||
|
|
||||||
Suppose we wish to build the following recurrent neural network architecture:
|
|
||||||
![RNN with Skip connections](/images/guide/lstm_skip_connection.png)
|
|
||||||
|
|
||||||
For the sake of this example, let's assume our input data is of size 5. Our configuration would be as follows:
|
|
||||||
|
|
||||||
```java
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.updater(new Sgd(0.01))
|
|
||||||
.graphBuilder()
|
|
||||||
.addInputs("input") //can use any label for this
|
|
||||||
.addLayer("L1", new GravesLSTM.Builder().nIn(5).nOut(5).build(), "input")
|
|
||||||
.addLayer("L2",new RnnOutputLayer.Builder().nIn(5+5).nOut(5).build(), "input", "L1")
|
|
||||||
.setOutputs("L2") //We need to specify the network outputs and their order
|
|
||||||
.build();
|
|
||||||
|
|
||||||
ComputationGraph net = new ComputationGraph(conf);
|
|
||||||
net.init();
|
|
||||||
```
|
|
||||||
|
|
||||||
Note that in the .addLayer(...) methods, the first string ("L1", "L2") is the name of that layer, and the strings at the end (["input"], ["input","L1"]) are the inputs to that layer.
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="multiin">Example 2: Multiple Inputs and Merge Vertex</a>
|
|
||||||
|
|
||||||
Consider the following architecture:
|
|
||||||
|
|
||||||
![Computation Graph with Merge Vertex](/images/guide/compgraph_merge.png)
|
|
||||||
|
|
||||||
Here, the merge vertex takes the activations out of layers L1 and L2, and merges (concatenates) them: thus if layers L1 and L2 both have 4 output activations (.nOut(4)) then the output size of the merge vertex is 4+4=8 activations.
|
|
||||||
|
|
||||||
To build the above network, we use the following configuration:
|
|
||||||
|
|
||||||
```java
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.updater(new Sgd(0.01))
|
|
||||||
.graphBuilder()
|
|
||||||
.addInputs("input1", "input2")
|
|
||||||
.addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1")
|
|
||||||
.addLayer("L2", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input2")
|
|
||||||
.addVertex("merge", new MergeVertex(), "L1", "L2")
|
|
||||||
.addLayer("out", new OutputLayer.Builder().nIn(4+4).nOut(3).build(), "merge")
|
|
||||||
.setOutputs("out")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
### <a name="multitask">Example 3: Multi-Task Learning</a>
|
|
||||||
|
|
||||||
In multi-task learning, a neural network is used to make multiple independent predictions.
|
|
||||||
Consider for example a simple network used for both classification and regression simultaneously. In this case, we have two output layers, "out1" for classification, and "out2" for regression.
|
|
||||||
|
|
||||||
![Computation Graph for MultiTask Learning](/images/guide/compgraph_multitask.png)
|
|
||||||
|
|
||||||
In this case, the network configuration is:
|
|
||||||
|
|
||||||
```java
|
|
||||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.updater(new Sgd(0.01))
|
|
||||||
.graphBuilder()
|
|
||||||
.addInputs("input")
|
|
||||||
.addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input")
|
|
||||||
.addLayer("out1", new OutputLayer.Builder()
|
|
||||||
.lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.nIn(4).nOut(3).build(), "L1")
|
|
||||||
.addLayer("out2", new OutputLayer.Builder()
|
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE)
|
|
||||||
.nIn(4).nOut(2).build(), "L1")
|
|
||||||
.setOutputs("out1","out2")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
### <a name="preprocessors">Automatically Adding PreProcessors and Calculating nIns</a>
|
|
||||||
|
|
||||||
One feature of the ComputationGraphConfiguration is that you can specify the types of input to the network, using the ```.setInputTypes(InputType...)``` method in the configuration.
|
|
||||||
|
|
||||||
The setInputTypes method has two effects:
|
|
||||||
|
|
||||||
1. It will automatically add any [InputPreProcessor](https://github.com/eclipse/deeplearning4j/tree/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor)s as required. InputPreProcessors are necessary to handle the interaction between for example fully connected (dense) and convolutional layers, or recurrent and fully connected layers.
|
|
||||||
2. It will automatically calculate the number of inputs (.nIn(x) config) to a layer. Thus, if you are using the ```setInputTypes(InputType...)``` functionality, it is not necessary to manually specify the .nIn(x) options in your configuration. This can simplify building some architectures (such as convolutional networks with fully connected layers). If the .nIn(x) is specified for a layer, the network will not override this when using the InputType functionality.
|
|
||||||
|
|
||||||
|
|
||||||
For example, if your network has 2 inputs, one being a convolutional input and the other being a feed-forward input, you would use ```.setInputTypes(InputType.convolutional(depth,width,height), InputType.feedForward(feedForwardInputSize))```
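For example, here is a sketch of a simple feed-forward graph where setInputTypes removes the need for .nIn(x); the input size of 10 is an arbitrary assumption:

```java
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
    .graphBuilder()
    .addInputs("input")
    .setInputTypes(InputType.feedForward(10))                            //input size 10; nIn values below are inferred
    .addLayer("L1", new DenseLayer.Builder().nOut(4).build(), "input")   //no .nIn(x) needed
    .addLayer("out", new OutputLayer.Builder().nOut(3).build(), "L1")
    .setOutputs("out")
    .build();
```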
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="data">Training Data for ComputationGraph</a>
|
|
||||||
|
|
||||||
There are two types of data that can be used with the ComputationGraph.
|
|
||||||
|
|
||||||
### DataSet and the DataSetIterator
|
|
||||||
|
|
||||||
The DataSet class was originally designed for use with the MultiLayerNetwork; however, it can also be used with ComputationGraph, but only if that computation graph has a single input and output array. For computation graph architectures with more than one input array, or more than one output array, DataSet and DataSetIterator cannot be used (instead, use MultiDataSet/MultiDataSetIterator).
|
|
||||||
|
|
||||||
A DataSet object is basically a pair of INDArrays that hold your training data. In the case of RNNs, it may also include masking arrays (see [this](http://deeplearning4j.org/usingrnns) for more details). A DataSetIterator is essentially an iterator over DataSet objects.
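For instance, a DataSet for a single-input/single-output graph can be built directly from two INDArrays and passed to fit. The shapes below are arbitrary, and `net` is assumed to be an initialized ComputationGraph with matching layer sizes:

```java
INDArray features = Nd4j.rand(4, 3);   //4 examples, 3 input features each (arbitrary shapes)
INDArray labels   = Nd4j.rand(4, 2);   //4 examples, 2 output values each
DataSet dataSet = new DataSet(features, labels);
net.fit(dataSet);                      //works only for single-input/single-output graphs
```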
|
|
||||||
|
|
||||||
### MultiDataSet and the MultiDataSetIterator
|
|
||||||
|
|
||||||
MultiDataSet is the multiple-input and/or multiple-output version of DataSet. It may also include multiple mask arrays (for each input/output array) in the case of recurrent neural networks. As a general rule, you should use DataSet/DataSetIterator, unless you are dealing with multiple inputs and/or multiple outputs.
|
|
||||||
|
|
||||||
There are currently two ways to use a MultiDataSetIterator:
|
|
||||||
|
|
||||||
- By implementing the [MultiDataSetIterator](https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/MultiDataSetIterator.java) interface directly
|
|
||||||
- By using the [RecordReaderMultiDataSetIterator](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java) in conjunction with DataVec record readers
|
|
||||||
|
|
||||||
|
|
||||||
The RecordReaderMultiDataSetIterator provides a number of options for loading data. In particular, the RecordReaderMultiDataSetIterator provides the following functionality:
|
|
||||||
|
|
||||||
- Multiple DataVec RecordReaders may be used simultaneously
|
|
||||||
- The record readers need not be the same modality: for example, you can use an image record reader with a CSV record reader
|
|
||||||
- It is possible to use a subset of the columns in a RecordReader for different purposes - for example, the first 10 columns in a CSV could be your input, and the last 5 could be your output
|
|
||||||
- It is possible to convert single columns from a class index to a one-hot representation
|
|
||||||
|
|
||||||
|
|
||||||
Some basic examples on how to use the RecordReaderMultiDataSetIterator follow. You might also find [these unit tests](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java) to be useful.
|
|
||||||
|
|
||||||
### <a name="rrmdsi1">RecordReaderMultiDataSetIterator Example 1: Regression Data</a>
|
|
||||||
|
|
||||||
Suppose we have a CSV file with 5 columns, and we want to use the first 3 as our input, and the last 2 columns as our output (for regression). We can build a MultiDataSetIterator to do this as follows:
|
|
||||||
|
|
||||||
```java
|
|
||||||
int numLinesToSkip = 0;
|
|
||||||
String fileDelimiter = ",";
|
|
||||||
RecordReader rr = new CSVRecordReader(numLinesToSkip,fileDelimiter);
|
|
||||||
String csvPath = "/path/to/my/file.csv";
|
|
||||||
rr.initialize(new FileSplit(new File(csvPath)));
|
|
||||||
|
|
||||||
int batchSize = 4;
|
|
||||||
MultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(batchSize)
|
|
||||||
.addReader("myReader",rr)
|
|
||||||
.addInput("myReader",0,2) //Input: columns 0 to 2 inclusive
|
|
||||||
.addOutput("myReader",3,4) //Output: columns 3 to 4 inclusive
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="rrmdsi2">RecordReaderMultiDataSetIterator Example 2: Classification and Multi-Task Learning</a>
|
|
||||||
|
|
||||||
Suppose we have two separate CSV files, one for our inputs, and one for our outputs. Further suppose we are building a multi-task learning architecture, whereby we have two outputs - one for regression and one for classification.
|
|
||||||
For this example, let's assume the data is as follows:
|
|
||||||
|
|
||||||
- Input file: myInput.csv, and we want to use all columns as input (without modification)
|
|
||||||
- Output file: myOutput.csv.
|
|
||||||
- Network output 1 - regression: columns 0 to 3
|
|
||||||
- Network output 2 - classification: column 4 is the class index for classification, with 3 classes. Thus column 4 contains integer values [0,1,2] only, and we want to convert these indexes to a one-hot representation for classification.
|
|
||||||
|
|
||||||
In this case, we can build our iterator as follows:
|
|
||||||
|
|
||||||
```java
|
|
||||||
int numLinesToSkip = 0;
|
|
||||||
String fileDelimiter = ",";
|
|
||||||
|
|
||||||
RecordReader featuresReader = new CSVRecordReader(numLinesToSkip,fileDelimiter);
|
|
||||||
String featuresCsvPath = "/path/to/my/myInput.csv";
|
|
||||||
featuresReader.initialize(new FileSplit(new File(featuresCsvPath)));
|
|
||||||
|
|
||||||
RecordReader labelsReader = new CSVRecordReader(numLinesToSkip,fileDelimiter);
|
|
||||||
String labelsCsvPath = "/path/to/my/myOutput.csv";
|
|
||||||
labelsReader.initialize(new FileSplit(new File(labelsCsvPath)));
|
|
||||||
|
|
||||||
int batchSize = 4;
|
|
||||||
int numClasses = 3;
|
|
||||||
MultiDataSetIterator iterator = new RecordReaderMultiDataSetIterator.Builder(batchSize)
|
|
||||||
.addReader("csvInput", featuresReader)
|
|
||||||
.addReader("csvLabels", labelsReader)
|
|
||||||
.addInput("csvInput") //Input: all columns from input reader
|
|
||||||
.addOutput("csvLabels", 0, 3) //Output 1: columns 0 to 3 inclusive
|
|
||||||
.addOutputOneHot("csvLabels", 4, numClasses) //Output 2: column 4 -> convert to one-hot for classification
|
|
||||||
.build();
|
|
||||||
```
|
|
|
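The resulting MultiDataSet objects then contain one features array and two labels arrays. A minimal sketch of consuming them (the variable names below are illustrative only):

```java
MultiDataSet mds = iterator.next();
INDArray features = mds.getFeatures(0);           //all columns of myInput.csv for this minibatch
INDArray regressionLabels = mds.getLabels(0);     //columns 0 to 3 of myOutput.csv
INDArray classificationLabels = mds.getLabels(1); //one-hot representation of column 4 (3 classes)
```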
@ -1,15 +0,0 @@
|
||||||
---
|
|
||||||
title: Supported Convolutional Layers
|
|
||||||
short_title: Convolutional
|
|
||||||
description: Supported convolutional layers.
|
|
||||||
category: Models
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is a convolutional neural network?
|
|
||||||
|
|
||||||
Each layer in a neural network configuration represents a group of hidden units. When layers are stacked together, they represent a *deep neural network*.
|
|
||||||
|
|
||||||
## Available layers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,52 +0,0 @@
|
||||||
---
|
|
||||||
title: Custom Layers
|
|
||||||
short_title: Custom Layers
|
|
||||||
description: Extend DL4J functionality for custom layers.
|
|
||||||
category: Models
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Writing Your Custom Layer
|
|
||||||
|
|
||||||
There are two components to adding a custom layer:
|
|
||||||
|
|
||||||
1. Adding the layer configuration class: extends org.deeplearning4j.nn.conf.layers.Layer
|
|
||||||
2. Adding the layer implementation class: implements org.deeplearning4j.nn.api.Layer
|
|
||||||
|
|
||||||
The configuration layer ((1) above) class handles the settings. It's the one you would
|
|
||||||
use when constructing a MultiLayerNetwork or ComputationGraph. You can add custom
|
|
||||||
settings here, and use them in your layer.
|
|
||||||
|
|
||||||
The implementation layer ((2) above) class has parameters, and handles network forward
|
|
||||||
pass, backpropagation, etc. It is created from the org.deeplearning4j.nn.conf.layers.Layer.instantiate(...)
|
|
||||||
method. In other words: the instantiate method is how we go from the configuration
|
|
||||||
to the implementation; MultiLayerNetwork or ComputationGraph will call this method
|
|
||||||
when initializing the network.
|
|
||||||
|
|
||||||
Examples of these are CustomLayer (the configuration class) and CustomLayerImpl (the
|
|
||||||
implementation class). Both of these classes have extensive comments regarding
|
|
||||||
their methods.
|
|
||||||
|
|
||||||
You'll note that in Deeplearning4j there are two DenseLayer classes, two GravesLSTM classes,
|
|
||||||
etc: the reason is because one is for the configuration, one is for the implementation.
|
|
||||||
We have not followed this "same name" pattern here to hopefully avoid confusion.
|
|
||||||
|
|
||||||
## Testing Your Custom Layer
|
|
||||||
|
|
||||||
Once you have added a custom layer, it is necessary to run some tests to ensure
|
|
||||||
it is correct.
|
|
||||||
|
|
||||||
These tests should at a minimum include the following:
|
|
||||||
|
|
||||||
1. Tests to ensure that the JSON configuration (to/from JSON) works correctly
|
|
||||||
This is necessary for networks with your custom layer to function with both
|
|
||||||
model serialization (saving) and Spark training.
|
|
||||||
2. Gradient checks to ensure that the implementation is correct.
|
|
||||||
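For test (1), a minimal sketch of a JSON round-trip check might look like the following, assuming `CustomLayer` is your configuration class (as in the example below), JUnit is on the test classpath, and the layer sizes are placeholders:

```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(0, new CustomLayer())    //your custom layer configuration class
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).nOut(3).build())
        .build();

String json = conf.toJson();
MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json);
assertEquals(conf, fromJson);           //the same check should also pass for toYaml()/fromYaml()
```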
|
|
||||||
## Example
|
|
||||||
|
|
||||||
A full custom layer example is available in our [examples repository](https://github.com/eclipse/deeplearning4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/misc/customlayers).
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,83 +0,0 @@
|
||||||
---
|
|
||||||
title: Early Stopping
|
|
||||||
short_title: Early Stopping
|
|
||||||
description: Terminate a training session given certain conditions.
|
|
||||||
category: Tuning & Training
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is early stopping?
|
|
||||||
|
|
||||||
When training neural networks, numerous decisions need to be made regarding the settings (hyperparameters) used, in order to obtain good performance. One such hyperparameter is the number of training epochs: that is, how many full passes of the data set (epochs) should be used? If we use too few epochs, we might underfit (i.e., not learn everything we can from the training data); if we use too many epochs, we might overfit (i.e., fit the 'noise' in the training data, and not the signal).
|
|
||||||
|
|
||||||
Early stopping attempts to remove the need to manually set this value. It can also be considered a type of regularization method (like L1/L2 weight decay and dropout) in that it can stop the network from overfitting.
|
|
||||||
|
|
||||||
The idea behind early stopping is relatively simple:
|
|
||||||
|
|
||||||
* Split data into training and test sets
|
|
||||||
* At the end of each epoch (or, every N epochs):
|
|
||||||
* evaluate the network performance on the test set
|
|
||||||
* if the network outperforms the previous best model: save a copy of the network at the current epoch
|
|
||||||
* Take as our final model the model that has the best test set performance
|
|
||||||
|
|
||||||
|
|
||||||
This is shown graphically below:
|
|
||||||
|
|
||||||
![Early Stopping](/images/guide/earlystopping.png)
|
|
||||||
|
|
||||||
The best model is the one saved at the time of the vertical dotted line - i.e., the model with the best accuracy on the test set.
|
|
||||||
|
|
||||||
|
|
||||||
Using DL4J's early stopping functionality requires you to provide a number of configuration options:
|
|
||||||
|
|
||||||
* A score calculator, such as the *DataSetLossCalculator*([JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.html), [Source Code](https://github.com/eclipse/deeplearning4j/blob/c152293ef8d1094c281f5375ded61ff5f8eb6587/deeplearning4j-core/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java)) for a Multi Layer Network, or *DataSetLossCalculatorCG* ([JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculatorCG.html), [Source Code](https://github.com/eclipse/deeplearning4j/blob/c152293ef8d1094c281f5375ded61ff5f8eb6587/deeplearning4j-core/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculatorCG.java)) for a Computation Graph. This is used to calculate the score at every epoch (for example: the loss function value on a test set, or the accuracy on the test set)
|
|
||||||
* How frequently we want to calculate the score function (default: every epoch)
|
|
||||||
* One or more termination conditions, which tell the training process when to stop. There are two classes of termination conditions:
|
|
||||||
* Epoch termination conditions: evaluated every N epochs
|
|
||||||
* Iteration termination conditions: evaluated once per minibatch
|
|
||||||
* A model saver, that defines how models are saved
|
|
||||||
|
|
||||||
An example, with an epoch termination condition of maximum of 30 epochs, a maximum of 20 minutes training time, calculating the score every epoch, and saving the intermediate results to disk:
|
|
||||||
|
|
||||||
```java
|
|
||||||
|
|
||||||
MultiLayerConfiguration myNetworkConfiguration = ...;
|
|
||||||
DataSetIterator myTrainData = ...;
|
|
||||||
DataSetIterator myTestData = ...;
|
|
||||||
|
|
||||||
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
|
|
||||||
.epochTerminationConditions(new MaxEpochsTerminationCondition(30))
|
|
||||||
.iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
|
|
||||||
.scoreCalculator(new DataSetLossCalculator(myTestData, true))
|
|
||||||
.evaluateEveryNEpochs(1)
|
|
||||||
.modelSaver(new LocalFileModelSaver(directory))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf,myNetworkConfiguration,myTrainData);
|
|
||||||
|
|
||||||
//Conduct early stopping training:
|
|
||||||
EarlyStoppingResult result = trainer.fit();
|
|
||||||
|
|
||||||
//Print out the results:
|
|
||||||
System.out.println("Termination reason: " + result.getTerminationReason());
|
|
||||||
System.out.println("Termination details: " + result.getTerminationDetails());
|
|
||||||
System.out.println("Total epochs: " + result.getTotalEpochs());
|
|
||||||
System.out.println("Best epoch number: " + result.getBestModelEpoch());
|
|
||||||
System.out.println("Score at best epoch: " + result.getBestModelScore());
|
|
||||||
|
|
||||||
//Get the best model:
|
|
||||||
MultiLayerNetwork bestModel = result.getBestModel();
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also implement your own iteration and epoch termination conditions.
|
|
||||||
|
|
||||||
## Early Stopping w/ Parallel Wrapper
|
|
||||||
|
|
||||||
The early stopping implementation described above will only work with a single device. However, `EarlyStoppingParallelTrainer` provides similar functionality as early stopping and allows you to optimize for either multiple CPUs or GPUs. `EarlyStoppingParallelTrainer` wraps your model in a `ParallelWrapper` class and performs localized distributed training.
|
|
||||||
|
|
||||||
Note that `EarlyStoppingParallelTrainer` doesn't support all of the functionality of its single-device counterpart. It is not UI-compatible and may not work with complex iteration listeners. This is due to how the model is distributed and copied in the background.
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,212 +0,0 @@
|
||||||
---
|
|
||||||
title: Evaluation Classes for Neural Networks
|
|
||||||
short_title: Evaluation
|
|
||||||
description: Tools and classes for evaluating neural network performance
|
|
||||||
category: Tuning & Training
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
|
|
||||||
## Why evaluate?
|
|
||||||
|
|
||||||
When training or deploying a Neural Network it is useful to know the accuracy of your model. In DL4J the Evaluation Class and variants of the Evaluation Class are available to evaluate your model's performance.
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="classification">Evaluation for Classification</a>
|
|
||||||
|
|
||||||
The Evaluation class is used to evaluate the performance for binary and multi-class classifiers (including time series classifiers). This section covers basic usage of the Evaluation Class.
|
|
||||||
|
|
||||||
Given a dataset in the form of a DataSetIterator, the easiest way to perform evaluation is to use the built-in evaluate methods on MultiLayerNetwork and ComputationGraph:
|
|
||||||
```
|
|
||||||
DataSetIterator myTestData = ...
|
|
||||||
Evaluation eval = model.evaluate(myTestData);
|
|
||||||
```
|
|
||||||
|
|
||||||
However, evaluation can be performed on individual minibatches also. Here is an example taken from our dataexamples/CSVExample in the [Examples](https://github.com/eclipse/deeplearning4j-examples) project.
|
|
||||||
|
|
||||||
The CSV example has CSV data for 3 classes of flowers and builds a simple feed forward neural network to classify the flowers based on 4 measurements.
|
|
||||||
|
|
||||||
```
|
|
||||||
Evaluation eval = new Evaluation(3);
|
|
||||||
INDArray output = model.output(testData.getFeatures());
|
|
||||||
eval.eval(testData.getLabels(), output);
|
|
||||||
log.info(eval.stats());
|
|
||||||
```
|
|
||||||
|
|
||||||
The first line creates an Evaluation object with 3 classes.
|
|
||||||
The second line gets the model's predictions (the network output) for our test dataset.
|
|
||||||
The third line uses the eval method to compare the labels array from the test data with the predictions generated by the model.
|
|
||||||
The fourth line logs the evaluation data to the console.
|
|
||||||
|
|
||||||
The output:
|
|
||||||
|
|
||||||
```
|
|
||||||
Examples labeled as 0 classified by model as 0: 24 times
|
|
||||||
Examples labeled as 1 classified by model as 1: 11 times
|
|
||||||
Examples labeled as 1 classified by model as 2: 1 times
|
|
||||||
Examples labeled as 2 classified by model as 2: 17 times
|
|
||||||
|
|
||||||
|
|
||||||
==========================Scores========================================
|
|
||||||
# of classes: 3
|
|
||||||
Accuracy: 0.9811
|
|
||||||
Precision: 0.9815
|
|
||||||
Recall: 0.9722
|
|
||||||
F1 Score: 0.9760
|
|
||||||
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)
|
|
||||||
========================================================================
|
|
||||||
```
|
|
||||||
|
|
||||||
By default the .stats() method displays the confusion matrix entries (one per line), Accuracy, Precision, Recall and F1 Score. Additionally the Evaluation Class can also calculate and return the following values:
|
|
||||||
|
|
||||||
* Confusion Matrix
|
|
||||||
* False Positive/Negative Rate
|
|
||||||
* True Positive/Negative
|
|
||||||
* Class Counts
|
|
||||||
* F-beta, G-measure, Matthews Correlation Coefficient and more, see [Evaluation JavaDoc](https://deeplearning4j.org/api/latest/org/deeplearning4j/eval/Evaluation.html)
|
|
||||||
|
|
||||||
Display the Confusion Matrix.
|
|
||||||
|
|
||||||
```
|
|
||||||
System.out.println(eval.confusionToString());
|
|
||||||
```
|
|
||||||
|
|
||||||
Displays:
|
|
||||||
|
|
||||||
```
|
|
||||||
Predicted: 0 1 2
|
|
||||||
Actual:
|
|
||||||
0 0 | 16 0 0
|
|
||||||
1 1 | 0 19 0
|
|
||||||
2 2 | 0 0 18
|
|
||||||
```
|
|
||||||
|
|
||||||
Additionally, the confusion matrix can be accessed directly, or converted to CSV or HTML, using:
|
|
||||||
|
|
||||||
```
|
|
||||||
eval.getConfusionMatrix();
|
|
||||||
eval.getConfusionMatrix().toHTML();
|
|
||||||
eval.getConfusionMatrix().toCSV();
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="regression">Evaluation for Regression</a>
|
|
||||||
|
|
||||||
To evaluate a network performing regression, use the RegressionEvaluation class.
|
|
||||||
|
|
||||||
As with the Evaluation class, RegressionEvaluation on a DataSetIterator can be performed as follows:
|
|
||||||
```
|
|
||||||
DataSetIterator myTestData = ...
|
|
||||||
RegressionEvaluation eval = model.evaluateRegression(myTestData);
|
|
||||||
```
|
|
||||||
|
|
||||||
Here is a code snippet with a single column; in this case the neural network was predicting the age of shellfish based on measurements.
|
|
||||||
|
|
||||||
```
|
|
||||||
RegressionEvaluation eval = new RegressionEvaluation(1);
|
|
||||||
```
|
|
||||||
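As with classification, the evaluation object needs to see both labels and predictions before statistics can be printed. A minimal sketch, assuming `model` and a `testData` DataSet as in the classification example above:

```
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
```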
|
|
||||||
Print the statistics for the Evaluation.
|
|
||||||
|
|
||||||
```
|
|
||||||
System.out.println(eval.stats());
|
|
||||||
```
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
```
|
|
||||||
Column MSE MAE RMSE RSE R^2
|
|
||||||
col_0 7.98925e+00 2.00648e+00 2.82653e+00 5.01481e-01 7.25783e-01
|
|
||||||
```
|
|
||||||
|
|
||||||
Columns are Mean Squared Error, Mean Absolute Error, Root Mean Squared Error, Relative Squared Error, and R^2 Coefficient of Determination
|
|
||||||
|
|
||||||
See [RegressionEvaluation JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/eval/RegressionEvaluation.html)
|
|
||||||
|
|
||||||
### <a name="multiple">Performing Multiple Evaluations Simultaneously</a>
|
|
||||||
|
|
||||||
When performing multiple types of evaluations (for example, Evaluation and ROC on the same network and dataset) it is more efficient to do this in one pass of the dataset, as follows:
|
|
||||||
|
|
||||||
```
|
|
||||||
DataSetIterator testData = ...
|
|
||||||
Evaluation eval = new Evaluation();
|
|
||||||
ROC roc = new ROC();
|
|
||||||
model.doEvaluation(testData, eval, roc);
|
|
||||||
```
|
|
||||||
|
|
||||||
### <a name="timeseries">Evaluation of Time Series</a>
|
|
||||||
|
|
||||||
Time series evaluation is very similar to the above evaluation approaches. Evaluation in DL4J is performed on all (non-masked) time steps separately - for example, a time series of length 10 will contribute 10 predictions/labels to an Evaluation object.
|
|
||||||
One difference with time series is the (optional) presence of mask arrays, which are used to mark some time steps as missing or not present. See [Using RNNs - Masking](./deeplearning4j-nn-recurrent) for more details on masking.
|
|
||||||
|
|
||||||
For most users, it is simply sufficient to use the ```MultiLayerNetwork.evaluate(DataSetIterator)``` or ```MultiLayerNetwork.evaluateRegression(DataSetIterator)``` and similar methods. These methods will properly handle masking, if mask arrays are present.
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="binary">Evaluation for Binary Classifiers</a>
|
|
||||||
|
|
||||||
The EvaluationBinary is used for evaluating networks with binary classification outputs - these networks usually have Sigmoid activation functions and XENT loss functions. The typical classification metrics, such as accuracy, precision, recall, F1 score, etc. are calculated for each output.
|
|
||||||
|
|
||||||
```
|
|
||||||
EvaluationBinary eval = new EvaluationBinary(size);    //size: the number of independent binary outputs
|
|
||||||
```
|
|
||||||
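A minimal usage sketch, assuming a trained network `model`, a `DataSetIterator testData`, and `numOutputs` (the number of independent binary outputs) - all hypothetical names:

```
EvaluationBinary eval = new EvaluationBinary(numOutputs);
model.doEvaluation(testData, eval);
System.out.println(eval.stats());
```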
|
|
||||||
See [EvaluationBinary JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/eval/EvaluationBinary.html)
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="roc">ROC</a>
|
|
||||||
|
|
||||||
ROC (Receiver Operating Characteristic) is another commonly used evaluation metric for the evaluation of classifiers. Three ROC variants exist in DL4J:
|
|
||||||
|
|
||||||
- ROC - for single binary label (as a single column probability, or 2 column 'softmax' probability distribution).
|
|
||||||
- ROCBinary - for multiple binary labels
|
|
||||||
- ROCMultiClass - for evaluation of non-binary classifiers, using a "one vs. all" approach
|
|
||||||
|
|
||||||
These classes have the ability to calculate the area under ROC curve (AUROC) and area under Precision-Recall curve (AUPRC), via the ```calculateAUC()``` and ```calculateAUPRC()``` methods. Furthermore, the ROC and Precision-Recall curves can be obtained using ```getRocCurve()``` and ```getPrecisionRecallCurve()```.
|
|
||||||
|
|
||||||
The ROC and Precision-Recall curves can be exported to HTML for viewing using: ```EvaluationTools.exportRocChartsToHtmlFile(ROC, File)```, which will export a HTML file with both ROC and P-R curves, that can be viewed in a browser.
|
|
||||||
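A minimal sketch of computing AUROC and exporting the curves, assuming a trained binary classifier `model` and a `DataSetIterator testData` (hypothetical names):

```
ROC roc = new ROC();                         //default constructor: exact AUROC/AUPRC calculation
model.doEvaluation(testData, roc);
System.out.println("AUROC: " + roc.calculateAUC());
EvaluationTools.exportRocChartsToHtmlFile(roc, new File("rocCurve.html"));
```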
|
|
||||||
|
|
||||||
Note that all three support two modes of operation/calculation:
|
|
||||||
- Thresholded (approximate AUROC/AUPRC calculation, no memory issues)
|
|
||||||
- Exact (exact AUROC/AUPRC calculation, but can require large amount of memory with very large datasets - i.e., datasets with many millions of examples)
|
|
||||||
|
|
||||||
The number of bins can be set using the constructors. Exact mode can be selected using the default constructor ```new ROC()```, or explicitly using ```new ROC(0)```.
|
|
||||||
|
|
||||||
See the [ROC JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/eval/ROC.html) for further details.
|
|
||||||
|
|
||||||
### <a name="calibration">Evaluating Classifier Calibration</a>
|
|
||||||
|
|
||||||
Deeplearning4j also has the EvaluationCalibration class, which is designed to analyze the calibration of a classifier. It provides a number of tools for this purpose:
|
|
||||||
|
|
||||||
- Counts of the number of labels and predictions for each class
|
|
||||||
- Reliability diagram (or reliability curve)
|
|
||||||
- Residual plot (histogram)
|
|
||||||
- Histograms of probabilities, including probabilities for each class separately
|
|
||||||
|
|
||||||
Evaluation of a classifier using EvaluationCalibration is performed in a similar manner to the other evaluation classes.
|
|
||||||
The various plots/histograms can be exported to HTML for viewing using ```EvaluationTools.exportevaluationCalibrationToHtmlFile(EvaluationCalibration, File)```.
|
|
||||||
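A minimal sketch, assuming a trained classifier `model`, a `DataSetIterator testData`, and the default `EvaluationCalibration` constructor (all names here are illustrative):

```
EvaluationCalibration ec = new EvaluationCalibration();
model.doEvaluation(testData, ec);
EvaluationTools.exportevaluationCalibrationToHtmlFile(ec, new File("calibration.html"));
```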
|
|
||||||
### <a name="spark">Distributed Evaluation for Spark Networks</a>
|
|
||||||
|
|
||||||
SparkDl4jMultiLayer and SparkComputationGraph both have similar methods for evaluation:
|
|
||||||
```
|
|
||||||
Evaluation eval = SparkDl4jMultiLayer.evaluate(JavaRDD<DataSet>);
|
|
||||||
|
|
||||||
//Multiple evaluations in one pass:
|
|
||||||
SparkDl4jMultiLayer.doEvaluation(JavaRDD<DataSet>, IEvaluation...);
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="multitask">Evaluation for Multi-task Networks</a>
|
|
||||||
|
|
||||||
A multi-task network is a network that is trained to produce multiple outputs. For example a network given audio samples can be trained to both predict the language spoken and the gender of the speaker. Multi-task configuration is briefly described [here](./deeplearning4j-nn-computationgraph).
|
|
||||||
|
|
||||||
Evaluation classes useful for multi-task networks:
|
|
||||||
|
|
||||||
See [ROCMultiClass JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/eval/ROCMultiClass.html)
|
|
||||||
|
|
||||||
See [ROCBinary JavaDoc](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/eval/ROCBinary.html)
|
|
||||||
|
|
||||||
## Available evaluations
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,44 +0,0 @@
|
||||||
---
|
|
||||||
title: Deeplearning4j Iterators
|
|
||||||
short_title: Iterators
|
|
||||||
description: Data iteration tools for loading into neural networks.
|
|
||||||
category: Models
|
|
||||||
weight: 5
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is an iterator?
|
|
||||||
|
|
||||||
A dataset iterator allows for easy loading of data into neural networks and helps organize batching, conversion, and masking. The iterators included in Eclipse Deeplearning4j help with either user-provided data, or automatic loading of common benchmarking datasets such as MNIST and IRIS.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
For most use cases, initializing an iterator and passing it to a `MultiLayerNetwork` or `ComputationGraph` `fit()` method is all you need to begin training:
|
|
||||||
|
|
||||||
```java
|
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(conf);
|
|
||||||
model.init();
|
|
||||||
|
|
||||||
// pass an MNIST data iterator that automatically fetches data
|
|
||||||
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
|
|
||||||
model.fit(mnistTrain);
|
|
||||||
```
|
|
||||||
|
|
||||||
Many other methods also accept iterators for tasks such as evaluation:
|
|
||||||
|
|
||||||
```java
|
|
||||||
// passing directly to the neural network
|
|
||||||
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
|
|
||||||
model.evaluate(mnistTest);
|
|
||||||
|
|
||||||
// using an evaluation class
|
|
||||||
Evaluation eval = new Evaluation(10); //create an evaluation object with 10 possible classes
|
|
||||||
while(mnistTest.hasNext()){
|
|
||||||
DataSet next = mnistTest.next();
|
|
||||||
INDArray output = model.output(next.getFeatureMatrix()); //get the networks prediction
|
|
||||||
eval.eval(next.getLabels(), output); //check the prediction against the true class
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Available iterators
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,23 +0,0 @@
|
||||||
---
|
|
||||||
title: Supported Layers
|
|
||||||
short_title: Layers
|
|
||||||
description: Supported neural network layers.
|
|
||||||
category: Models
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## What are layers?
|
|
||||||
|
|
||||||
Each layer in a neural network configuration represents a group of hidden units. When layers are stacked together, they represent a *deep neural network*.
|
|
||||||
|
|
||||||
## Using layers
|
|
||||||
|
|
||||||
All layers available in Eclipse Deeplearning4j can be used either in a `MultiLayerNetwork` or `ComputationGraph`. When configuring a neural network, you pass the layer configuration and the network will instantiate the layer for you.
|
|
||||||
|
|
||||||
## Layers vs. vertices
|
|
||||||
|
|
||||||
If you are configuring complex networks such as InceptionV4, you will need to use the `ComputationGraph` API and join different branches together using vertices. See the vertices documentation for more information.
|
|
||||||
|
|
||||||
## General layers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,26 +0,0 @@
|
||||||
---
|
|
||||||
title: Deeplearning4j Listeners
|
|
||||||
short_title: Listeners
|
|
||||||
description: Adding hooks and listeners on DL4J models.
|
|
||||||
category: Models
|
|
||||||
weight: 5
|
|
||||||
---
|
|
||||||
|
|
||||||
## What are listeners?
|
|
||||||
|
|
||||||
Listeners allow users to "hook" into certain events in Eclipse Deeplearning4j. This allows you to collect or print information useful for tasks like training. For example, a `ScoreIterationListener` allows you to print training scores from the output layer of a neural network.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
To add one or more listeners to a `MultiLayerNetwork` or `ComputationGraph`, use the `setListeners` method:
|
|
||||||
|
|
||||||
```java
|
|
||||||
MultiLayerNetwork model = new MultiLayerNetwork(conf);
|
|
||||||
model.init();
|
|
||||||
//print the score every iteration
|
|
||||||
model.setListeners(new ScoreIterationListener(1));
|
|
||||||
```
|
|
||||||
|
|
||||||
## Available listeners
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,28 +0,0 @@
|
||||||
---
|
|
||||||
title: Deeplearning4j Model Persistence
|
|
||||||
short_title: Model Persistence
|
|
||||||
description: Saving and loading of neural networks.
|
|
||||||
category: Models
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Saving and Loading a Neural Network
|
|
||||||
|
|
||||||
The `ModelSerializer` is a class which handles loading and saving models. There are two methods for saving models, shown in the examples linked below. The first example saves a normal multilayer network, the second one saves a [computation graph](https://deeplearning4j.org/docs/latest/deeplearning4j-nn-computationgraph).
|
|
||||||
|
|
||||||
Here is a [basic example](https://github.com/eclipse/deeplearning4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/misc/modelsaving) with code to save a computation graph using the `ModelSerializer` class, as well as an example of using ModelSerializer to save a neural net built using MultiLayer configuration.
|
|
||||||
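A minimal save/restore sketch using `ModelSerializer`, assuming `model` is a trained `MultiLayerNetwork` (the file name and variable names are illustrative):

```java
File modelFile = new File("myModel.zip");
boolean saveUpdater = true;    //true: also save the updater state (e.g. momentum), needed if you want to continue training later
ModelSerializer.writeModel(model, modelFile, saveUpdater);

MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile);
```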
|
|
||||||
### RNG Seed
|
|
||||||
|
|
||||||
If your model uses probabilities (i.e. DropOut/DropConnect), it may make sense to set the RNG seed separately, and apply it when the model is restored; i.e.:
|
|
||||||
|
|
||||||
```java
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
|
||||||
ModelSerializer.restoreMultiLayerNetwork(modelFile);
|
|
||||||
```
|
|
||||||
|
|
||||||
This will guarantee equal results between sessions/JVMs.
|
|
||||||
|
|
||||||
## Model serializer
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,81 +0,0 @@
|
||||||
---
|
|
||||||
title: Multilayer Network
|
|
||||||
short_title: Multilayer Network
|
|
||||||
description: Simple and sequential network configuration.
|
|
||||||
category: Models
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
## Why use MultiLayerNetwork?
|
|
||||||
|
|
||||||
The `MultiLayerNetwork` class is the simplest network configuration API available in Eclipse Deeplearning4j. This class is useful for beginners or users who do not need a complex and branched network graph.
|
|
||||||
|
|
||||||
You will not want to use `MultiLayerNetwork` configuration if you are creating complex loss functions, using graph vertices, or doing advanced training such as a triplet network. This includes popular complex networks such as InceptionV4.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
The example below shows how to build a simple linear classifier using `DenseLayer` (a basic multilayer perceptron layer).
|
|
||||||
|
|
||||||
```java
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.seed(seed)
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
.learningRate(learningRate)
|
|
||||||
.updater(Updater.NESTEROVS).momentum(0.9)
|
|
||||||
.list()
|
|
||||||
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.activation("relu")
|
|
||||||
.build())
|
|
||||||
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.activation("softmax").weightInit(WeightInit.XAVIER)
|
|
||||||
.nIn(numHiddenNodes).nOut(numOutputs).build())
|
|
||||||
.pretrain(false).backprop(true).build();
|
|
||||||
```
|
|
||||||
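To actually use this configuration, it can be wrapped in a `MultiLayerNetwork` and trained; a minimal sketch, assuming `trainIter` is a `DataSetIterator` for your training data:

```java
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(100));   //log the score every 100 iterations
model.fit(trainIter);
```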
|
|
||||||
You can also create convolutional configurations:
|
|
||||||
|
|
||||||
```java
|
|
||||||
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
|
|
||||||
.seed(seed)
|
|
||||||
.regularization(true).l2(0.0005)
|
|
||||||
.learningRate(0.01)//.biasLearningRate(0.02)
|
|
||||||
//.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
.updater(Updater.NESTEROVS).momentum(0.9)
|
|
||||||
.list()
|
|
||||||
.layer(0, new ConvolutionLayer.Builder(5, 5)
|
|
||||||
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
|
|
||||||
.nIn(nChannels)
|
|
||||||
.stride(1, 1)
|
|
||||||
.nOut(20)
|
|
||||||
.activation("identity")
|
|
||||||
.build())
|
|
||||||
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
|
|
||||||
.kernelSize(2,2)
|
|
||||||
.stride(2,2)
|
|
||||||
.build())
|
|
||||||
.layer(2, new ConvolutionLayer.Builder(5, 5)
|
|
||||||
//Note that nIn need not be specified in later layers
|
|
||||||
.stride(1, 1)
|
|
||||||
.nOut(50)
|
|
||||||
.activation("identity")
|
|
||||||
.build())
|
|
||||||
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
|
|
||||||
.kernelSize(2,2)
|
|
||||||
.stride(2,2)
|
|
||||||
.build())
|
|
||||||
.layer(4, new DenseLayer.Builder().activation("relu")
|
|
||||||
.nOut(500).build())
|
|
||||||
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.nOut(outputNum)
|
|
||||||
.activation("softmax")
|
|
||||||
.build())
|
|
||||||
.backprop(true).pretrain(false);
|
|
||||||
```
|
|
||||||
|
|
||||||
## API
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,355 +0,0 @@
|
||||||
---
|
|
||||||
title: Recurrent Neural Networks in DL4J
|
|
||||||
short_title: RNN
|
|
||||||
description: Recurrent Neural Network implementations in DL4J.
|
|
||||||
category: Models
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## Recurrent Neural Networks in DL4J
|
|
||||||
|
|
||||||
This document outlines the specific training features for recurrent neural networks and the practicalities of how to use them in DeepLearning4J. It is not an introduction to recurrent neural networks, and assumes some familiarity with both their use and terminology.
|
|
||||||
|
|
||||||
**Contents**
|
|
||||||
|
|
||||||
* [The Basics: Data and Network Configuration](#basics)
|
|
||||||
* [RNN Training Features](#trainingfeatures)
|
|
||||||
* [Truncated Back Propagation Through Time](#tbptt)
|
|
||||||
* [Masking: One-to-Many, Many-to-One, and Sequence Classification](#masking)
|
|
||||||
* [Masking and Sequence Classification After Training](#testtimemasking)
|
|
||||||
* [Combining RNN Layers with Other Layer Types](#otherlayertypes)
|
|
||||||
* [Test Time: Prediction One Step at a Time](#rnntimestep)
|
|
||||||
* [Importing Time Series Data](#data)
|
|
||||||
* [Examples](#examples)
|
|
||||||
|
|
||||||
## <a name="basics">The Basics: Data and Network Configuration</a>
|
|
||||||
DL4J currently supports the following types of recurrent neural network:
|
|
||||||
* GravesLSTM (Long Short-Term Memory)
|
|
||||||
* BidirectionalGravesLSTM
|
|
||||||
* BaseRecurrent
|
|
||||||
|
|
||||||
Java documentation for each is available, [GravesLSTM](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/GravesLSTM.html),
|
|
||||||
[BidirectionalGravesLSTM](https://deeplearning4j.org/api/{{page.version}}/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.html), [BaseRecurrent](https://deeplearning4j.org/api/latest/org/deeplearning4j/nn/conf/layers/BaseRecurrentLayer.html)
|
|
||||||
|
|
||||||
#### Data for RNNs
|
|
||||||
Consider for the moment a standard feed-forward network (a multi-layer perceptron or 'DenseLayer' in DL4J). These networks expect input and output data that is two-dimensional: that is, data with "shape" [numExamples,inputSize]. This means that the data into a feed-forward network has ‘numExamples’ rows/examples, where each row consists of ‘inputSize’ columns. A single example would have shape [1,inputSize], though in practice we generally use multiple examples for computational and optimization efficiency. Similarly, output data for a standard feed-forward network is also two dimensional, with shape [numExamples,outputSize].
|
|
||||||
|
|
||||||
Conversely, data for RNNs are time series. Thus, they have 3 dimensions: one additional dimension for time. Input data thus has shape [numExamples,inputSize,timeSeriesLength], and output data has shape [numExamples,outputSize,timeSeriesLength]. This means that the data in our INDArray is laid out such that the value at position (i,j,k) is the jth value at the kth time step of the ith example in the minibatch. This data layout is shown below.
|
|
||||||
|
|
||||||
When importing time series data using the class CSVSequenceRecordReader, each line in the data files represents one time step, with the earliest observation in the first row (or the first row after the header, if present) and the most recent observation in the last row of the CSV. Each feature time series is a separate column of the CSV file. For example, if you have five features in your time series, each with 120 observations, and a training & test set of size 53, then there will be 106 CSV files (53 input, 53 labels). The 53 input CSV files will each have five columns and 120 rows. The label CSV files will have one column (the label) and one row.
|
|
||||||
|
|
||||||
![Data: Feed Forward vs. RNN](/images/guide/rnn_data.png)
|
|
||||||
|
|
||||||
#### RnnOutputLayer
|
|
||||||
|
|
||||||
RnnOutputLayer is a type of layer used as the final layer with many recurrent neural network systems (for both regression and classification tasks). RnnOutputLayer handles things like score calculation, and error calculation (of prediction vs. actual) given a loss function etc. Functionally, it is very similar to the 'standard' OutputLayer class (which is used with feed-forward networks); however it both outputs (and expects as labels/targets) 3d time series data sets.
|
|
||||||
|
|
||||||
Configuration for the RnnOutputLayer follows the same design as other layers: for example, to set the third layer in a MultiLayerNetwork to a RnnOutputLayer for classification:
|
|
||||||
|
|
||||||
```
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation(Activation.SOFTMAX)
        .weightInit(WeightInit.XAVIER).nIn(prevLayerSize).nOut(nOut).build())
```
|
|
||||||
|
|
||||||
Use of RnnOutputLayer in practice can be seen in the examples, linked at the end of this document.
|
|
||||||
|
|
||||||
## <a name="trainingfeatures">RNN Training Features</a>
|
|
||||||
|
|
||||||
### <a name="tbptt">Truncated Back Propagation Through Time</a>
|
|
||||||
Training neural networks (including RNNs) can be quite computationally demanding. For recurrent neural networks, this is especially the case when we are dealing with long sequences - i.e., training data with many time steps.
|
|
||||||
|
|
||||||
Truncated backpropagation through time (BPTT) was developed in order to reduce the computational complexity of each parameter update in a recurrent neural network. In summary, it allows us to train networks faster (by performing more frequent parameter updates), for a given amount of computational power. It is recommended to use truncated BPTT when your input sequences are long (typically, more than a few hundred time steps).
|
|
||||||
|
|
||||||
Consider what happens when training a recurrent neural network with a time series of length 12 time steps. Here, we need to do a forward pass of 12 steps, calculate the error (based on predicted vs. actual), and do a backward pass of 12 time steps:
|
|
||||||
|
|
||||||
![Standard Backprop Training](/images/guide/rnn_tbptt_1.png)
|
|
||||||
|
|
||||||
For 12 time steps, in the image above, this is not a problem. Consider, however, that instead the input time series was 10,000 or more time steps. In this case, standard backpropagation through time would require 10,000 time steps for each of the forward and backward passes for each and every parameter update. This is of course very computationally demanding.
|
|
||||||
|
|
||||||
In practice, truncated BPTT splits the forward and backward passes into a set of smaller forward/backward pass operations. The specific length of these forward/backward pass segments is a parameter set by the user. For example, if we use truncated BPTT of length 4 time steps, learning looks like the following:
|
|
||||||
|
|
||||||
![Truncated BPTT](/images/guide/rnn_tbptt_2.png)
|
|
||||||
|
|
||||||
Note that the overall complexity for truncated BPTT and standard BPTT are approximately the same - both do the same number of time steps during the forward/backward pass. Using this method however, we get 3 parameter updates instead of one for approximately the same amount of effort. However, the cost is not exactly the same: there is a small amount of overhead per parameter update.
|
|
||||||
|
|
||||||
The downside of truncated BPTT is that the length of the dependencies learned in truncated BPTT can be shorter than in full BPTT. This is easy to see: consider the images above, with a TBPTT length of 4. Suppose that at time step 10, the network needs to store some information from time step 0 in order to make an accurate prediction. In standard BPTT, this is ok: the gradients can flow backwards all the way along the unrolled network, from time 10 to time 0. In truncated BPTT, this is problematic: the gradients from time step 10 simply don't flow back far enough to cause the required parameter updates that would store the required information. This tradeoff is usually worth it, and (as long as the truncated BPTT lengths are set appropriately), truncated BPTT works well in practice.
|
|
||||||
|
|
||||||
Using truncated BPTT in DL4J is quite simple: just add the following code to your network configuration (at the end, before the final .build() in your network configuration)
|
|
||||||
|
|
||||||
```
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTLength(100)
```
|
|
||||||
|
|
||||||
The above code snippet will cause any network training (i.e., calls to MultiLayerNetwork.fit() methods) to use truncated BPTT with segments of length 100 steps.
|
|
||||||
|
|
||||||
Some things of note:
|
|
||||||
|
|
||||||
* By default (if a backprop type is not manually specified), DL4J will use BackpropType.Standard (i.e., full BPTT).
|
|
||||||
* The tBPTTLength configuration parameter sets the length of the truncated BPTT passes. Typically, this is somewhere on the order of 50 to 200 time steps, though it depends on the application and data.
|
|
||||||
* The truncated BPTT length is typically a fraction of the total time series length (i.e., 200 vs. sequence length 1000), but variable length time series in the same minibatch are OK when using TBPTT (for example, a minibatch with two sequences - one of length 100 and another of length 1000 - with a TBPTT length of 200 - will work correctly)
|
|
||||||
|
|
||||||
### <a name="masking">Masking: One-to-Many, Many-to-One, and Sequence Classification</a>
|
|
||||||
|
|
||||||
DL4J supports a number of related training features for RNNs, based on the idea of padding and masking. Padding and masking allow us to support training situations including one-to-many and many-to-one, as well as variable length time series (in the same mini-batch).
|
|
||||||
|
|
||||||
Suppose we want to train a recurrent neural network with inputs or outputs that don't occur at every time step. Examples of this (for a single example) are shown in the image below. DL4J supports training networks for all of these situations:
|
|
||||||
|
|
||||||
![RNN Training Types](/images/guide/rnn_masking_1.png)
|
|
||||||
|
|
||||||
Without masking and padding, we are restricted to the many-to-many case (above, left): that is, (a) All examples are of the same length, and (b) Examples have both inputs and outputs at all time steps.
|
|
||||||
|
|
||||||
The idea behind padding is simple. Consider two time series of lengths 50 and 100 time steps, in the same mini-batch. The training data is a rectangular array; thus, we pad (i.e., add zeros to) the shorter time series (for both input and output), such that the input and output are both the same length (in this example: 100 time steps).
|
|
||||||
|
|
||||||
Of course, if this was all we did, it would cause problems during training. Thus, in addition to padding, we use a masking mechanism. The idea behind masking is simple: we have two additional arrays that record whether an input or output is actually present for a given time step and example, or whether the input/output is just padding.
|
|
||||||
|
|
||||||
Recall that with RNNs, our minibatch data has 3 dimensions, with shape [miniBatchSize,inputSize,timeSeriesLength] and [miniBatchSize,outputSize,timeSeriesLength] for the input and output respectively. The mask arrays are then 2 dimensional, with shape [miniBatchSize,timeSeriesLength] for both the input and output, with values of 0 ('absent') or 1 ('present') for each time series and example. The masking arrays for the input and output are stored in separate arrays.
|
|
||||||
|
|
||||||
For a single example, the input and output masking arrays are shown below:
|
|
||||||
|
|
||||||
![RNN Training Types](/images/guide/rnn_masking_2.png)
|
|
||||||
|
|
||||||
For the “Masking not required” cases, we could equivalently use a masking array of all 1s, which will give the same result as not having a mask array at all. Also note that it is possible to use zero, one or two masking arrays when learning RNNs - for example, the many-to-one case could have a masking array for the output only.
|
|
||||||
|
|
||||||
In practice: these padding and mask arrays are generally created during the data import stage (for example, by the SequenceRecordReaderDataSetIterator – discussed later), and are contained within the DataSet object. If a DataSet contains mask arrays, the MultiLayerNetwork fit methods will automatically use them during training. If they are absent, no masking functionality is used.
|
|
||||||
|
|
||||||
#### Evaluation and Scoring with Masking
|
|
||||||
|
|
||||||
Mask arrays are also important when doing scoring and evaluation (i.e., when evaluating the accuracy of a RNN classifier). Consider for example the many-to-one case: there is only a single output for each example, and any evaluation should take this into account.
|
|
||||||
|
|
||||||
The (output) mask array can be used during evaluation by passing it to the following method:
|
|
||||||
|
|
||||||
```
Evaluation.evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask)
```
|
|
||||||
|
|
||||||
where labels are the actual output (3d time series), predicted is the network predictions (3d time series, same shape as labels), and outputMask is the 2d mask array for the output. Note that the input mask array is not required for evaluation.
|
|
||||||
|
|
||||||
Score calculation will also make use of the mask arrays, via the MultiLayerNetwork.score(DataSet) method. Again, if the DataSet contains an output masking array, it will automatically be used when calculating the score (loss function - mean squared error, negative log likelihood etc) for the network.
|
|
||||||
|
|
||||||
#### <a name="testtimemasking">Masking and Sequence Classification After Training</a>
|
|
||||||
|
|
||||||
Sequence classification is one common use of masking. The idea is that although we have a sequence (time series) as input, we only want to provide a single label for the entire sequence (rather than one label at each time step in the sequence).
|
|
||||||
|
|
||||||
However, RNNs by design output sequences of the same length as the input sequence. For sequence classification, masking allows us to train the network with this single label at the final time step - we essentially tell the network that there isn't *actually* label data anywhere except for the last time step.
|
|
||||||
|
|
||||||
Now, suppose we've trained our network, and want to get the last time step for predictions, from the time series output array. How do we do that?
|
|
||||||
|
|
||||||
|
|
||||||
To get the last time step, there are two cases to be aware of. First, when we have a single example, we don't actually need to use the mask arrays: we can just get the last time step in the output array:
|
|
||||||
|
|
||||||
```
|
|
||||||
INDArray timeSeriesFeatures = ...;
|
|
||||||
INDArray timeSeriesOutput = myNetwork.output(timeSeriesFeatures);
|
|
||||||
int timeSeriesLength = timeSeriesOutput.size(2); //Size of time dimension
|
|
||||||
INDArray lastTimeStepProbabilities = timeSeriesOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength-1));
|
|
||||||
```
|
|
||||||
|
|
||||||
Assuming classification (same process for regression, however) the last line above gives us probabilities at the last time step - i.e., the class probabilities for our sequence classification.
|
|
||||||
|
|
||||||
|
|
||||||
The slightly more complex case is when we have multiple examples in the one minibatch (features array), where the lengths of each example differ. (If all are the same length: we can use the same process as above).
|
|
||||||
|
|
||||||
In this 'variable length' case, we need to get the last time step *for each example separately*. If we have the time series lengths for each example from our data pipeline, it becomes straightforward: we just iterate over examples, replacing the ```timeSeriesLength``` in the above code with the length of that example.
|
|
||||||
|
|
||||||
If we don't have the lengths of the time series directly, we need to extract them from the mask array.
|
|
||||||
|
|
||||||
If we have a labels mask array (which is a one-hot vector, like [0,0,0,1,0] for each time series):
|
|
||||||
|
|
||||||
```
|
|
||||||
INDArray labelsMaskArray = ...;
|
|
||||||
INDArray lastTimeStepIndices = Nd4j.argMax(labelsMaskArray,1);
|
|
||||||
```
|
|
||||||
|
|
||||||
Alternatively, if we have only the features mask, one quick and dirty approach is to use this:
|
|
||||||
|
|
||||||
```
|
|
||||||
INDArray featuresMaskArray = ...;
|
|
||||||
int longestTimeSeries = featuresMaskArray.size(1);
|
|
||||||
INDArray linspace = Nd4j.linspace(1,longestTimeSeries,longestTimeSeries);
|
|
||||||
INDArray temp = featuresMaskArray.mulColumnVector(linspace);
|
|
||||||
INDArray lastTimeStepIndices = Nd4j.argMax(temp,1);
|
|
||||||
```
|
|
||||||
To understand what is happening here, note that originally we have a features mask like [1,1,1,1,0], from which we want to get the last non-zero element. So we map [1,1,1,1,0] -> [1,2,3,4,0], and then get the largest element (which is the last time step).
|
|
||||||
|
|
||||||
|
|
||||||
In either case, we can then do the following:
|
|
||||||
|
|
||||||
```
|
|
||||||
int numExamples = timeSeriesFeatures.size(0);
|
|
||||||
for( int i=0; i<numExamples; i++ ){
|
|
||||||
int thisTimeSeriesLastIndex = lastTimeStepIndices.getInt(i);
|
|
||||||
INDArray thisExampleProbabilities = timeSeriesOutput.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(thisTimeSeriesLastIndex));
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="otherlayertypes">Combining RNN Layers with Other Layer Types</a>
|
|
||||||
|
|
||||||
RNN layers in DL4J can be combined with other layer types. For example, it is possible to combine DenseLayer and LSTM layers in the same network; or combine Convolutional (CNN) layers and LSTM layers for video.
|
|
||||||
|
|
||||||
Of course, the DenseLayer and Convolutional layers do not handle time series data - they expect a different type of input. To deal with this, we need to use the layer preprocessor functionality: for example, the CnnToRnnPreProcessor and FeedForwardToRnnPreprocessor classes. See [here](https://github.com/eclipse/deeplearning4j/tree/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor) for all preprocessors. Fortunately, in most situations, the DL4J configuration system will automatically add these preprocessors as required. However, the preprocessors can be added manually (overriding the automatic addition of preprocessors, for each layer).
|
|
||||||
|
|
||||||
For example, to manually add a preprocessor between layers 1 and 2, add the following to your network configuration: `.inputPreProcessor(2, new RnnToFeedForwardPreProcessor())`.
|
|
||||||
|
|
||||||
## <a name="rnntimestep">Test Time: Predictions One Step at a Time</a>
|
|
||||||
As with other types of neural networks, predictions can be generated for RNNs using the `MultiLayerNetwork.output()` and `MultiLayerNetwork.feedForward()` methods. These methods can be useful in many circumstances; however, they have the limitation that we can only generate predictions for time series, starting from scratch each and every time.
|
|
||||||
|
|
||||||
Consider for example the case where we want to generate predictions in a real-time system, where these predictions are based on a very large amount of history. In this case, it is impractical to use the output/feedForward methods, as they conduct the full forward pass over the entire data history, each time they are called. If we wish to make a prediction for a single time step, at every time step, these methods can be both (a) very costly, and (b) wasteful, as they do the same calculations over and over.
|
|
||||||
|
|
||||||
For these situations, MultiLayerNetwork provides four methods of note:
|
|
||||||
|
|
||||||
* `rnnTimeStep(INDArray)`
|
|
||||||
* `rnnClearPreviousState()`
|
|
||||||
* `rnnGetPreviousState(int layer)`
|
|
||||||
* `rnnSetPreviousState(int layer, Map<String,INDArray> state)`
|
|
||||||
|
|
||||||
The rnnTimeStep() method is designed to allow forward pass (predictions) to be conducted efficiently, one or more steps at a time. Unlike the output/feedForward methods, the rnnTimeStep method keeps track of the internal state of the RNN layers when it is called. It is important to note that output for the rnnTimeStep and the output/feedForward methods should be identical (for each time step), whether we make these predictions all at once (output/feedForward) or whether these predictions are generated one or more steps at a time (rnnTimeStep). Thus, the only difference should be the computational cost.
|
|
||||||
|
|
||||||
In summary, the MultiLayerNetwork.rnnTimeStep() method does two things:
|
|
||||||
|
|
||||||
1. Generate output/predictions (forward pass), using the previous stored state (if any)
|
|
||||||
2. Update the stored state, storing the activations for the last time step (ready to be used next time rnnTimeStep is called)
|
|
||||||
|
|
||||||
For example, suppose we want to use a RNN to predict the weather, one hour in advance (based on the weather at say the previous 100 hours as input).
|
|
||||||
If we were to use the output method, at each hour we would need to feed in the full 100 hours of data to predict the weather for hour 101. Then to predict the weather for hour 102, we would need to feed in the full 100 (or 101) hours of data; and so on for hours 103+.
|
|
||||||
|
|
||||||
Alternatively, we could use the rnnTimeStep method. Of course, if we want to use the full 100 hours of history before we make our first prediction, we still need to do the full forward pass:
|
|
||||||
|
|
||||||
![RNN Time Step](/images/guide/rnn_timestep_1.png)
|
|
||||||
|
|
||||||
For the first time we call rnnTimeStep, the only practical difference between the two approaches is that the activations/state of the last time step are stored - this is shown in orange. However, the next time we use the rnnTimeStep method, this stored state will be used to make the next predictions:
|
|
||||||
|
|
||||||
![RNN Time Step](/images/guide/rnn_timestep_2.png)
|
|
||||||
|
|
||||||
There are a number of important differences here:
|
|
||||||
|
|
||||||
1. In the second image (second call of rnnTimeStep) the input data consists of a single time step, instead of the full history of data
|
|
||||||
2. The forward pass is thus a single time step (as compared to the hundreds – or more)
|
|
||||||
3. After the rnnTimeStep method returns, the internal state will automatically be updated. Thus, predictions for time 103 could be made in the same way as for time 102. And so on.
|
|
||||||
|
|
||||||
However, if you want to start making predictions for a new (entirely separate) time series: it is necessary (and important) to manually clear the stored state, using the `MultiLayerNetwork.rnnClearPreviousState()` method. This will reset the internal state of all recurrent layers in the network.
|
|
||||||
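Putting this together, a minimal sketch of step-by-step prediction (here `net` is a trained `MultiLayerNetwork`, `history` is a 3d [numExamples,nIn,timeSteps] array, and `nextStep` is a 2d [numExamples,nIn] array - all hypothetical names):

```
net.rnnClearPreviousState();                        //start from a clean state, e.g. for a new time series
INDArray historyOutput = net.rnnTimeStep(history);  //full forward pass over the history; the state is stored internally
INDArray nextOutput = net.rnnTimeStep(nextStep);    //a single step; uses and then updates the stored state
```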
|
|
||||||
If you need to store or set the internal state of the RNN for use in predictions, you can use the rnnGetPreviousState and rnnSetPreviousState methods, for each layer individually. This can be useful for example during serialization (network saving/loading), as the internal network state from the rnnTimeStep method is *not* saved by default, and must be saved and loaded separately. Note that these get/set state methods return and accept a map, keyed by the type of activation. For example, in the LSTM model, it is necessary to store both the output activations, and the memory cell state.
|
|
||||||
|
|
||||||
Some other points of note:
|
|
||||||
|
|
||||||
- We can use the rnnTimeStep method for multiple independent examples/predictions simultaneously. In the weather example above, we might for example want to make predictions for multiple locations using the same neural network. This works in the same way as training and the forward pass / output methods: multiple rows (dimension 0 in the input data) are used for multiple examples.
|
|
||||||
- If no history/stored state is set (i.e., initially, or after a call to rnnClearPreviousState), a default initialization (zeros) is used. This is the same approach as during training.
|
|
||||||
- The rnnTimeStep can be used for an arbitrary number of time steps simultaneously – not just one time step. However, it is important to note:
|
|
||||||
- For a single time step prediction: the data is 2 dimensional, with shape [numExamples,nIn]; in this case, the output is also 2 dimensional, with shape [numExamples,nOut]
|
|
||||||
- For multiple time step predictions: the data is 3 dimensional, with shape [numExamples,nIn,numTimeSteps]; the output will have shape [numExamples,nOut,numTimeSteps]. Again, the final time step activations are stored as before.
|
|
||||||
- It is not possible to change the number of examples between calls of rnnTimeStep (in other words, if the first use of rnnTimeStep is for say 3 examples, all subsequent calls must be with 3 examples). After resetting the internal state (using rnnClearPreviousState()), any number of examples can be used for the next call of rnnTimeStep.
|
|
||||||
- The rnnTimeStep method makes no changes to the parameters; it is used only after training of the network has been completed.
|
|
||||||
- The rnnTimeStep method works with networks containing single and stacked/multiple RNN layers, as well as with networks that combine other layer types (such as Convolutional or Dense layers).
|
|
||||||
- The RnnOutputLayer layer type does not have any internal state, as it does not have any recurrent connections.
|
|
||||||
|
|
||||||
## <a name="data">Importing Time Series Data</a>
|
|
||||||
|
|
||||||
Data import for RNNs is complicated by the fact that we have multiple different types of data we could want to use for RNNs: one-to-many, many-to-one, variable length time series, etc. This section will describe the currently implemented data import mechanisms for DL4J.
|
|
||||||
|
|
||||||
The methods described here utilize the SequenceRecordReaderDataSetIterator class, in conjunction with the CSVSequenceRecordReader class from DataVec. This approach currently allows you to load delimited (tab, comma, etc) data from files, where each time series is in a separate file.
|
|
||||||
This method also supports:
|
|
||||||
|
|
||||||
* Variable length time series input
|
|
||||||
* One-to-many and many-to-one data loading (where input and labels are in different files)
|
|
||||||
* Label conversion from an index to a one-hot representation for classification (i.e., '2' to [0,0,1,0])
|
|
||||||
* Skipping a fixed/specified number of rows at the start of the data files (i.e., comment or header rows)
|
|
||||||
|
|
||||||
Note that in all cases, each line in the data files represents one time step.
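For example, a single (hypothetical) input file for a time series with 3 input columns, 4 time steps, and one header row to skip might look like:

```
colA,colB,colC
0.1,1.2,0.0
0.2,1.1,0.5
0.3,0.9,0.7
0.4,0.8,0.6
```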
|
|
||||||
|
|
||||||
(In addition to the examples below, you might find [these unit tests](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java) to be of some use.)
|
|
||||||
|
|
||||||
#### Example 1: Time Series of Same Length, Input and Labels in Separate Files
|
|
||||||
|
|
||||||
Suppose we have 10 time series in our training data, represented by 20 files: 10 files for the input of each time series, and 10 files for the output/labels. For now, assume these 20 files all contain the same number of time steps (i.e., same number of rows).
|
|
||||||
|
|
||||||
To use the [SequenceRecordReaderDataSetIterator](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java) and [CSVSequenceRecordReader](https://github.com/eclipse/deeplearning4j/blob/master/datavec/datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java) approaches, we first create two CSVSequenceRecordReader objects, one for input and one for labels:
|
|
||||||
|
|
||||||
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
|
|
||||||
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
|
|
||||||
|
|
||||||
This particular constructor takes the number of lines to skip (1 row skipped here), and the delimiter (comma character used here).
|
|
||||||
|
|
||||||
Second, we need to initialize these two readers, by telling them where to get the data from. We do this with an InputSplit object.
|
|
||||||
Suppose that our time series are numbered, with file names "myInput_0.csv", "myInput_1.csv", ..., "myLabels_0.csv", etc. One approach is to use the [NumberedFileInputSplit](https://github.com/eclipse/deeplearning4j/blob/master/datavec/datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java):
|
|
||||||
|
|
||||||
featureReader.initialize(new NumberedFileInputSplit("/path/to/data/myInput_%d.csv", 0, 9));
|
|
||||||
labelReader.initialize(new NumberedFileInputSplit("/path/to/data/myLabels_%d.csv", 0, 9));
|
|
||||||
|
|
||||||
In this particular approach, the "%d" is replaced by the corresponding number, and the numbers 0 to 9 (both inclusive) are used.
|
|
||||||
|
|
||||||
Finally, we can create our SequenceRecordReaderDataSetIterator:
|
|
||||||
|
|
||||||
DataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, miniBatchSize, numPossibleLabels, regression);
|
|
||||||
|
|
||||||
This DataSetIterator can then be passed to MultiLayerNetwork.fit() to train the network.
|
|
||||||
|
|
||||||
The miniBatchSize argument specifies the number of examples (time series) in each minibatch. For example, with 10 files total, a miniBatchSize of 5 would give us two minibatches (DataSet objects) with 5 time series in each.
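As a quick sanity check (a sketch only; it simply inspects the first minibatch), the feature arrays produced by this iterator are 3d, with shape [miniBatchSize, numInputColumns, numTimeSteps]:

```
DataSet first = iter.next();
System.out.println(java.util.Arrays.toString(first.getFeatures().shape()));   //[miniBatchSize, numInputColumns, numTimeSteps]
System.out.println(java.util.Arrays.toString(first.getLabels().shape()));     //[miniBatchSize, numPossibleLabels, numTimeSteps] for classification
iter.reset();                                                                  //Reset before passing the iterator to fit()
```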
|
|
||||||
|
|
||||||
Note that:
|
|
||||||
|
|
||||||
* For classification problems: numPossibleLabels is the number of classes in your data set. Use regression = false.
|
|
||||||
* Labels data: one value per line, as a class index
|
|
||||||
* Label data will be converted to a one-hot representation automatically
|
|
||||||
* For regression problems: numPossibleLabels is not used (set it to anything) and use regression = true.
|
|
||||||
    * The number of values in the input and labels can be anything (unlike classification, regression can have an arbitrary number of output values)
|
|
||||||
* No processing of the labels is done when regression = true
|
|
||||||
|
|
||||||
#### Example 2: Time Series of Same Length, Input and Labels in Same File
|
|
||||||
|
|
||||||
Following on from the last example, suppose that instead of separate files for our input data and labels, we have both in the same file. However, each time series is still in a separate file.
|
|
||||||
|
|
||||||
As of DL4J 0.4-rc3.8, this approach has the restriction of a single column for the output (either a class index, or a single real-valued regression output)
|
|
||||||
|
|
||||||
In this case, we create and initialize a single reader. Again, we are skipping one header row, and specifying the format as comma delimited, and assuming our data files are named "myData_0.csv", ..., "myData_9.csv":
|
|
||||||
|
|
||||||
SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
|
|
||||||
reader.initialize(new NumberedFileInputSplit("/path/to/data/myData_%d.csv", 0, 9));
|
|
||||||
DataSetIterator iterClassification = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, numPossibleLabels, labelIndex, false);
|
|
||||||
|
|
||||||
`miniBatchSize` and `numPossibleLabels` are the same as the previous example. Here, `labelIndex` specifies which column the labels are in. For example, if the labels are in the fifth column, use labelIndex = 4 (i.e., columns are indexed 0 to numColumns-1).
|
|
||||||
|
|
||||||
For regression on a single output value, we use:
|
|
||||||
|
|
||||||
DataSetIterator iterRegression = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, -1, labelIndex, true);
|
|
||||||
|
|
||||||
Again, the numPossibleLabels argument is not used for regression.
|
|
||||||
|
|
||||||
#### Example 3: Time Series of Different Lengths (Many-to-Many)
|
|
||||||
|
|
||||||
Following on from the previous two examples, suppose that for each example individually, the input and labels are of the same length, but these lengths differ between time series.
|
|
||||||
|
|
||||||
We can use the same approach (CSVSequenceRecordReader and SequenceRecordReaderDataSetIterator), though with a different constructor:
|
|
||||||
|
|
||||||
DataSetIterator variableLengthIter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, miniBatchSize, numPossibleLabels, regression, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
|
|
||||||
|
|
||||||
The arguments here are the same as in the previous example, with the exception of the AlignmentMode.ALIGN_END addition. This alignment mode tells the SequenceRecordReaderDataSetIterator to expect two things:
|
|
||||||
|
|
||||||
1. That the time series may be of different lengths
|
|
||||||
2. To align the input and labels - for each example individually - such that their last values occur at the same time step.
|
|
||||||
|
|
||||||
Note that if the features and labels are always of the same length (as is the assumption in example 3), then the two alignment modes (AlignmentMode.ALIGN_END and AlignmentMode.ALIGN_START) will give identical outputs. The alignment mode option is explained in the next section.
|
|
||||||
|
|
||||||
Also note that variable-length time series always start at time zero in the data arrays: padding, if required, will be added after the time series has ended.
|
|
||||||
|
|
||||||
Unlike examples 1 and 2 above, the DataSet objects produced by the above variableLengthIter instance will also include feature and label masking arrays, as described earlier in this document.
|
|
||||||
|
|
||||||
#### Example 4: Many-to-One and One-to-Many Data
|
|
||||||
We can also use the AlignmentMode functionality in example 3 to implement a many-to-one RNN sequence classifier. Here, let us assume:
|
|
||||||
|
|
||||||
* Input and labels are in separate delimited files
|
|
||||||
* The labels files contain a single row (time step) (either a class index for classification, or one or more numbers for regression)
|
|
||||||
* The input lengths may (optionally) differ between examples
|
|
||||||
|
|
||||||
In fact, the same approach as in example 3 can do this:
|
|
||||||
|
|
||||||
DataSetIterator variableLengthIter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, miniBatchSize, numPossibleLabels, regression, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
|
|
||||||
|
|
||||||
Alignment modes are relatively straightforward. They specify whether to pad the start or the end of the shorter time series. The diagram below shows how this works, along with the masking arrays (as discussed earlier in this document):
|
|
||||||
|
|
||||||
![Sequence Alignment](/images/guide/rnn_seq_alignment.png)
|
|
||||||
|
|
||||||
The one-to-many case (similar to the last case above, but with only one input) is done by using AlignmentMode.ALIGN_START.
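A sketch of the one-to-many construction (identical to the call above, except for the alignment mode):

```
DataSetIterator oneToManyIter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader,
        miniBatchSize, numPossibleLabels, regression,
        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
```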
|
|
||||||
|
|
||||||
Note that in the case of training data that contains time series of different lengths, the labels and inputs will be aligned for each example individually, and then the shorter time series will be padded as required:
|
|
||||||
|
|
||||||
![Sequence Alignment](/images/guide/rnn_seq_alignment_2.png)
|
|
||||||
|
|
||||||
## Available layers
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,160 +0,0 @@
|
||||||
---
|
|
||||||
title: Neural Network Transfer Learning
|
|
||||||
short_title: Transfer Learning
|
|
||||||
description:
|
|
||||||
category: Tuning & Training
|
|
||||||
weight: 5
|
|
||||||
---
|
|
||||||
|
|
||||||
## DL4J’s Transfer Learning API
|
|
||||||
|
|
||||||
The DL4J transfer learning API enables users to:
|
|
||||||
|
|
||||||
* Modify the architecture of an existing model
|
|
||||||
* Fine tune learning configurations of an existing model.
|
|
||||||
* Hold parameters of a specified layer constant during training, also referred to as "frozen"
|
|
||||||
|
|
||||||
Holding certain layers frozen on a network and training is effectively the same as training on a transformed version of the input, the transformed version being the intermediate outputs at the boundary of the frozen layers. This is the process of “feature extraction” from the input data and will be referred to as “featurizing” in this document.
|
|
||||||
|
|
||||||
|
|
||||||
## The transfer learning helper
|
|
||||||
|
|
||||||
The forward pass to “featurize” the input data on large, pretrained networks can be time consuming. DL4J also provides a TransferLearningHelper class with the following capabilities:
|
|
||||||
|
|
||||||
* Featurize an input dataset to save for future use
|
|
||||||
* Fit the model with frozen layers with a featurized dataset
|
|
||||||
* Output from the model with frozen layers given a featurized input.
|
|
||||||
|
|
||||||
When running multiple epochs users will save on computation time since the expensive forward pass on the frozen layers/vertices will only have to be conducted once.
|
|
||||||
|
|
||||||
|
|
||||||
## Show me the code
|
|
||||||
|
|
||||||
This example will use VGG16 to classify images belonging to five categories of flowers. The dataset will automatically download from http://download.tensorflow.org/example_images/flower_photos.tgz
|
|
||||||
|
|
||||||
#### I. Import a zoo model
|
|
||||||
|
|
||||||
As of 0.9.0 (0.8.1-SNAPSHOT) Deeplearning4j has a new native model zoo. Read about the [deeplearning4j-zoo](/model-zoo) module for more information on using pretrained models. Here, we load a pretrained VGG-16 model initialized with weights trained on ImageNet:
|
|
||||||
|
|
||||||
```
|
|
||||||
ZooModel zooModel = new VGG16();
|
|
||||||
ComputationGraph pretrainedNet = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
#### II. Set up a fine-tune configuration
|
|
||||||
|
|
||||||
```
|
|
||||||
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
.updater(new Nesterovs(5e-5))
|
|
||||||
.seed(seed)
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
#### III. Build new models based on VGG16
|
|
||||||
|
|
||||||
##### A. Modifying only the last layer, keeping the others frozen
|
|
||||||
|
|
||||||
The final layer of VGG16 does a softmax regression on the 1000 classes in ImageNet. We modify the very last layer to give predictions for five classes, keeping the other layers frozen.
|
|
||||||
|
|
||||||
```
|
|
||||||
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(pretrainedNet)
|
|
||||||
.fineTuneConfiguration(fineTuneConf)
|
|
||||||
.setFeatureExtractor("fc2")
|
|
||||||
.removeVertexKeepConnections("predictions")
|
|
||||||
.addLayer("predictions",
|
|
||||||
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.nIn(4096).nOut(numClasses)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.activation(Activation.SOFTMAX).build(), "fc2")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
After a mere thirty iterations, which in this case is exposure to 450 images, the model attains an accuracy > 75% on the test dataset. This is rather remarkable considering the complexity of training an image classifier from scratch.
|
|
||||||
|
|
||||||
##### B. Attach new layers to the bottleneck (block5_pool)
|
|
||||||
|
|
||||||
Here we hold all but the last three dense layers frozen and attach new dense layers onto it. Note that the primary intent here is to demonstrate the use of the API; whether this particular setup gives better results is secondary.
|
|
||||||
|
|
||||||
```
|
|
||||||
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(pretrainedNet)
|
|
||||||
.fineTuneConfiguration(fineTuneConf)
|
|
||||||
.setFeatureExtractor("block5_pool")
|
|
||||||
.nOutReplace("fc2",1024, WeightInit.XAVIER)
|
|
||||||
.removeVertexAndConnections("predictions")
|
|
||||||
.addLayer("fc3",new DenseLayer.Builder()
|
|
||||||
.activation(Activation.RELU)
|
|
||||||
.nIn(1024).nOut(256).build(),"fc2")
|
|
||||||
.addLayer("newpredictions",new OutputLayer
|
|
||||||
.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.activation(Activation.SOFTMAX)
|
|
||||||
.nIn(256).nOut(numClasses).build(),"fc3")
|
|
||||||
.setOutputs("newpredictions")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
##### C. Fine tune layers from a previously saved model
|
|
||||||
|
|
||||||
Say we have saved off our model from (B) and now want to allow the "block5" layers to train.
|
|
||||||
|
|
||||||
```
|
|
||||||
ComputationGraph vgg16FineTune = new TransferLearning.GraphBuilder(vgg16Transfer)
|
|
||||||
.fineTuneConfiguration(fineTuneConf)
|
|
||||||
.setFeatureExtractor("block4_pool")
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
#### IV. Saving “featurized” datasets and training with them.
|
|
||||||
|
|
||||||
We use the transfer learning helper API. Note this freezes the layers of the model passed in.
|
|
||||||
|
|
||||||
Here is how you obtain the featurized version of the dataset at the specified layer "fc2":
|
|
||||||
|
|
||||||
```
|
|
||||||
TransferLearningHelper transferLearningHelper =
|
|
||||||
new TransferLearningHelper(pretrainedNet, "fc2");
|
|
||||||
while(trainIter.hasNext()) {
|
|
||||||
DataSet currentFeaturized = transferLearningHelper.featurize(trainIter.next());
|
|
||||||
saveToDisk(currentFeaturized,trainDataSaved,true);
|
|
||||||
trainDataSaved++;
|
|
||||||
}
|
|
||||||
```
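Note that `saveToDisk` above is a helper from the example code rather than a DL4J API method; a minimal sketch of what it could look like (folder and file names are illustrative) is:

```
public static void saveToDisk(DataSet currentFeaturized, int iterNum, boolean isTrain) {
    File fileFolder = isTrain ? new File("trainFolder") : new File("testFolder");
    fileFolder.mkdirs();
    String fileName = "flowers-" + (isTrain ? "train-" : "test-") + iterNum + ".bin";
    currentFeaturized.save(new File(fileFolder, fileName));    //DataSet.save(File) serializes the featurized minibatch
}
```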
|
|
||||||
|
|
||||||
Here is how you can fit with a featurized dataset. vgg16Transfer is the model set up in (A) of section III.
|
|
||||||
|
|
||||||
```
|
|
||||||
TransferLearningHelper transferLearningHelper =
|
|
||||||
new TransferLearningHelper(vgg16Transfer);
|
|
||||||
while (trainIter.hasNext()) {
|
|
||||||
transferLearningHelper.fitFeaturized(trainIter.next());
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Notes
|
|
||||||
|
|
||||||
* The TransferLearning builder returns a new instance of a dl4j model.
|
|
||||||
|
|
||||||
Keep in mind this is a second model that leaves the original one untouched. For large pretrained networks, take the memory requirements into consideration and adjust your JVM heap space accordingly.
|
|
||||||
|
|
||||||
* The trained model helper imports models from Keras without enforcing a training configuration.
|
|
||||||
|
|
||||||
Therefore the last layer (as seen when printing the summary) is a dense layer and not an output layer with a loss function. To modify nOut of an output layer, we delete the layer vertex, keeping its connections, and add back in a new output layer with the same name, a different nOut, a suitable loss function, etc.
|
|
||||||
|
|
||||||
* Changing nOuts at a layer/vertex will modify nIn of the layers/vertices it fans into.
|
|
||||||
|
|
||||||
When changing nOut, users can specify a weight initialization scheme or a distribution for the layer, as well as a separate weight initialization scheme or distribution for the layers it fans out to.
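For example (a sketch only, reusing the VGG16 layer names from earlier; the values are illustrative), replacing nOut of "fc2" with one initialization scheme for "fc2" itself and another for the layers it fans into:

```
ComputationGraph modifiedVgg = new TransferLearning.GraphBuilder(pretrainedNet)
        .fineTuneConfiguration(fineTuneConf)
        .nOutReplace("fc2", 1024, WeightInit.XAVIER, WeightInit.RELU)   //XAVIER for fc2, RELU for the layers it fans into
        .build();
```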
|
|
||||||
|
|
||||||
* Frozen layer configurations are not saved when writing the model to disk.
|
|
||||||
|
|
||||||
In other words, a model with frozen layers when serialized and read back in will not have any frozen layers. To continue training holding specific layers constant the user is expected to go through the transfer learning helper or the transfer learning API. There are two ways to “freeze” layers in a dl4j model.
|
|
||||||
|
|
||||||
- On a copy: With the transfer learning API which will return a new model with the relevant frozen layers
|
|
||||||
- In place: With the transfer learning helper API which will apply the frozen layers to the given model.
|
|
||||||
|
|
||||||
* FineTune configurations will selectively update learning parameters.
|
|
||||||
|
|
||||||
For example, if a learning rate is specified, this learning rate will apply to all unfrozen/trainable layers in the model. However, newly added layers can override this learning rate by specifying their own learning rates in the layer builder, as in the sketch below.
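For illustration (a sketch only; the specific updater and learning rate values are assumptions, not taken from the example above):

```
ComputationGraph vgg16Transfer = new TransferLearning.GraphBuilder(pretrainedNet)
        .fineTuneConfiguration(fineTuneConf)                 //e.g., Nesterovs(5e-5) applies to all unfrozen layers
        .setFeatureExtractor("fc2")
        .removeVertexKeepConnections("predictions")
        .addLayer("predictions",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(4096).nOut(numClasses)
                        .updater(new Nesterovs(1e-3))        //The newly added layer overrides the fine-tune learning rate
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX).build(), "fc2")
        .build();
```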
|
|
||||||
|
|
||||||
## Utilities
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,72 +0,0 @@
|
||||||
---
|
|
||||||
title: t-SNE's Data Visualization
|
|
||||||
short_title: t-SNE Visualization
|
|
||||||
description: Data visualization of higher-dimensional data with t-SNE.
|
|
||||||
category: Tuning & Training
|
|
||||||
weight: 10
|
|
||||||
---
|
|
||||||
|
|
||||||
## t-SNE's Data Visualization
|
|
||||||
|
|
||||||
[t-Distributed Stochastic Neighbor Embedding](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding) (t-SNE) is a data-visualization tool created by Laurens van der Maaten at Delft University of Technology.
|
|
||||||
|
|
||||||
While it can be used for any data, t-SNE (pronounced Tee-Snee) is only really meaningful with labeled data, which clarifies how the input is clustering. Below, you can see the kind of graphic you can generate in DL4J with t-SNE working on MNIST data.
|
|
||||||
|
|
||||||
![Alt text](/images/guide/tsne.png)
|
|
||||||
|
|
||||||
Look closely and you can see the numerals clustered near their likes, alongside the dots.
|
|
||||||
|
|
||||||
Here's how t-SNE appears in Deeplearning4j code.
|
|
||||||
|
|
||||||
```java
|
|
||||||
public class TSNEStandardExample {
|
|
||||||
|
|
||||||
private static Logger log = LoggerFactory.getLogger(TSNEStandardExample.class);
|
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
|
||||||
//STEP 1: Initialization
|
|
||||||
int iterations = 100;
|
|
||||||
//create an n-dimensional array of doubles
|
|
||||||
DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
|
|
||||||
List<String> cacheList = new ArrayList<>(); //cacheList is a dynamic array of strings used to hold all words
|
|
||||||
|
|
||||||
//STEP 2: Turn text input into a list of words
|
|
||||||
log.info("Load & Vectorize data....");
|
|
||||||
File wordFile = new ClassPathResource("words.txt").getFile(); //Open the file
|
|
||||||
//Get the data of all unique word vectors
|
|
||||||
Pair<InMemoryLookupTable,VocabCache> vectors = WordVectorSerializer.loadTxt(wordFile);
|
|
||||||
VocabCache cache = vectors.getSecond();
|
|
||||||
INDArray weights = vectors.getFirst().getSyn0();    //separate weights of unique words into their own list
|
|
||||||
|
|
||||||
for(int i = 0; i < cache.numWords(); i++)   //separate strings of words into their own list
|
|
||||||
cacheList.add(cache.wordAtIndex(i));
|
|
||||||
|
|
||||||
//STEP 3: build a dual-tree tsne to use later
|
|
||||||
log.info("Build model....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(iterations).theta(0.5)
|
|
||||||
.normalize(false)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
// .usePca(false)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
//STEP 4: establish the tsne values and save them to a file
|
|
||||||
log.info("Store TSNE Coordinates for Plotting....");
|
|
||||||
String outputFile = "target/archive-tmp/tsne-standard-coords.csv";
|
|
||||||
(new File(outputFile)).getParentFile().mkdirs();
|
|
||||||
tsne.plot(weights,2,cacheList,outputFile);
|
|
||||||
//This tsne will use the weights of the vectors as its matrix, have two dimensions, use the words strings as
|
|
||||||
//labels, and be written to the outputFile created on the previous line
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Here is an image of the tsne-standard-coords.csv file plotted using gnuplot.
|
|
||||||
|
|
||||||
|
|
||||||
![Tsne data plot](/images/guide/tsne_output.png)
|
|
|
@ -1,15 +0,0 @@
|
||||||
---
|
|
||||||
title: Supported Vertices
|
|
||||||
short_title: Vertices
|
|
||||||
description: Computation graph nodes for advanced configuration.
|
|
||||||
category: Models
|
|
||||||
weight: 4
|
|
||||||
---
|
|
||||||
|
|
||||||
## What is a vertex?
|
|
||||||
|
|
||||||
In Eclipse Deeplearning4j a vertex is a type of layer that acts as a node in a `ComputationGraph`. It can accept multiple inputs, provide multiple outputs, and can help construct popular networks such as InceptionV4.
|
|
||||||
|
|
||||||
## Available classes
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,325 +0,0 @@
|
||||||
---
|
|
||||||
title: Visualize, Monitor and Debug Neural Network Learning
|
|
||||||
short_title: Visualization
|
|
||||||
description: How to visualize, monitor and debug neural network learning.
|
|
||||||
category: Tuning & Training
|
|
||||||
weight: 2
|
|
||||||
---
|
|
||||||
|
|
||||||
## Contents
|
|
||||||
|
|
||||||
* [Visualizing Network Training with the Deeplearning4j Training UI](#ui)
|
|
||||||
* [Deeplearning4j UI: The Overview Page](#overviewpage)
|
|
||||||
* [Deeplearning4j UI: The Model Page](#modelpage)
|
|
||||||
* [Deeplearning4J UI and Spark Training](#sparkui)
|
|
||||||
* [Using the UI to Tune Your Network](#usingui)
|
|
||||||
* [TSNE and Word2Vec](#tsne)
|
|
||||||
* [Fixing UI Issue: "No configuration setting" exception](#issues)
|
|
||||||
|
|
||||||
## <a name="ui">Visualizing Network Training with the Deeplearning4j Training UI</a>
|
|
||||||
|
|
||||||
**Note**: The information here pertains to DL4J versions 0.7.0 and later.
|
|
||||||
|
|
||||||
DL4J provides a user interface to visualize, in your browser and in real time, the current network status and progress of training. The UI is typically used to help with tuning neural networks - i.e., the selection of hyperparameters (such as learning rate) to obtain good performance for a network.
|
|
||||||
|
|
||||||
**Step 1: Add the Deeplearning4j UI dependency to your project.**
|
|
||||||
|
|
||||||
```
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-ui_2.10</artifactId>
|
|
||||||
<version>{{ page.version }}</version>
|
|
||||||
</dependency>
|
|
||||||
```
|
|
||||||
|
|
||||||
Note the ```_2.10``` suffix: this is the Scala version (due to using the Play framework, a Scala library, for the backend). If you are not using other Scala libraries, either ```_2.10``` or ```_2.11``` is OK.
|
|
||||||
|
|
||||||
**Step 2: Enable the UI in your project**
|
|
||||||
|
|
||||||
This is relatively straightforward:
|
|
||||||
|
|
||||||
```
|
|
||||||
//Initialize the user interface backend
|
|
||||||
UIServer uiServer = UIServer.getInstance();
|
|
||||||
|
|
||||||
//Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
|
|
||||||
StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later
|
|
||||||
|
|
||||||
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
|
|
||||||
uiServer.attach(statsStorage);
|
|
||||||
|
|
||||||
//Then add the StatsListener to collect this information from the network, as it trains
|
|
||||||
net.setListeners(new StatsListener(statsStorage));
|
|
||||||
```
|
|
||||||
|
|
||||||
To access the UI, open your browser and go to ```http://localhost:9000/train```.
|
|
||||||
You can set the port by using the ```org.deeplearning4j.ui.port``` system property: i.e., to use port 9001, pass the following to the JVM on launch: ```-Dorg.deeplearning4j.ui.port=9001```
|
|
||||||
|
|
||||||
Information will then be collected and routed to the UI when you call the ```fit``` method on your network.
|
|
||||||
|
|
||||||
|
|
||||||
**Example:** [See a UI example here](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/userInterface/UIExample.java)
|
|
||||||
|
|
||||||
The full set of UI examples are available [here](https://github.com/eclipse/deeplearning4j-examples/tree/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/userInterface).
|
|
||||||
|
|
||||||
|
|
||||||
### <a name="overviewpage">Deeplearning4j UI: The Overview Page</a>
|
|
||||||
|
|
||||||
![Overview Page](/images/guide/DL4J_UI_01.png)
|
|
||||||
|
|
||||||
The overview page (one of 3 available pages) contains the following information:
|
|
||||||
|
|
||||||
- Top left: score vs iteration chart - this is the value of the loss function on the current minibatch
|
|
||||||
- Top right: model and training information
|
|
||||||
- Bottom left: Ratio of parameters to updates (by layer) for all network weights vs. iteration
|
|
||||||
- Bottom right: Standard deviations (vs. time) of: activations, gradients and updates
|
|
||||||
|
|
||||||
Note that for the bottom two charts, these are displayed as the logarithm (base 10) of the values. Thus a value of -3 on the update: parameter ratio chart corresponds to a ratio of 10<sup>-3</sup> = 0.001.
|
|
||||||
|
|
||||||
The ratio of updates to parameters is specifically the ratio of mean magnitudes of these values: i.e., log<sub>10</sub>(mean(abs(updates)) / mean(abs(parameters))).
|
|
||||||
|
|
||||||
See the later section of this page on how to use these values in practice.
|
|
||||||
|
|
||||||
### <a name="modelpage">Deeplearning4j UI: The Model Page</a>
|
|
||||||
|
|
||||||
![Model Page](/images/guide/DL4J_UI_02.png)
|
|
||||||
|
|
||||||
The model page contains a graph of the neural network layers, which operates as a selection mechanism. Click on a layer to display information for it.
|
|
||||||
|
|
||||||
On the right, the following charts are available, after selecting a layer:
|
|
||||||
|
|
||||||
- Table of layer information
|
|
||||||
- Update to parameter ratio for this layer, as per the overview page. The components of this ratio (the parameter and update mean magnitudes) are also available via tabs.
|
|
||||||
- Layer activations (mean and mean +/- 2 standard deviations) over time
|
|
||||||
- Histograms of parameters and updates, for each parameter type
|
|
||||||
- Learning rate vs. time (note this will be flat, unless learning rate schedules are used)
|
|
||||||
|
|
||||||
|
|
||||||
*Note: parameters are labeled as follows: weights (W) and biases (b). For recurrent neural networks, W refers to the weights connecting the layer to the layer below, and RW refers to the recurrent weights (i.e., those between time steps).*
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="sparkui">Deeplearning4J UI and Spark Training</a>
|
|
||||||
|
|
||||||
The DL4J UI can be used with Spark. However, as of 0.7.0, conflicting dependencies mean that running the UI and Spark in the same JVM can be difficult.
|
|
||||||
|
|
||||||
Two alternatives are available:
|
|
||||||
|
|
||||||
1. Collect and save the relevant stats, to be visualized (offline) at a later point
|
|
||||||
2. Run the UI in a separate server, and use the remote UI functionality to upload the data from the Spark master to your UI instance
|
|
||||||
|
|
||||||
**Collecting Stats for Later Offline Use**
|
|
||||||
|
|
||||||
```
|
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
|
|
||||||
|
|
||||||
StatsStorage ss = new FileStatsStorage(new File("myNetworkTrainingStats.dl4j"));
|
|
||||||
sparkNet.setListeners(ss, Collections.singletonList(new StatsListener(null)));
|
|
||||||
```
|
|
||||||
|
|
||||||
Then, later you can load and display the saved information using:
|
|
||||||
|
|
||||||
```
|
|
||||||
StatsStorage statsStorage = new FileStatsStorage(statsFile); //If file already exists: load the data from it
|
|
||||||
UIServer uiServer = UIServer.getInstance();
|
|
||||||
uiServer.attach(statsStorage);
|
|
||||||
```
|
|
||||||
|
|
||||||
**Using the Remote UI Functionality**
|
|
||||||
|
|
||||||
First, in the JVM running the UI (note this is the server):
|
|
||||||
|
|
||||||
```
|
|
||||||
UIServer uiServer = UIServer.getInstance();
|
|
||||||
uiServer.enableRemoteListener(); //Necessary: remote support is not enabled by default
|
|
||||||
```
|
|
||||||
This will require the ```deeplearning4j-ui_2.10``` or ```deeplearning4j-ui_2.11``` dependency. (Note: this is the server, not the client; the client uses the ```deeplearning4j-ui-model``` dependency, as described below.)
|
|
||||||
|
|
||||||
Client side (this applies to both Spark and standalone neural networks using plain deeplearning4j-nn):
|
|
||||||
Second, for your neural net (note this example is for Spark, but ComputationGraph and MultiLayerNetwork both have the equivalent setListeners method with the same usage; [example found here](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/userInterface/RemoteUIExample.java)):
|
|
||||||
|
|
||||||
```
|
|
||||||
SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
|
|
||||||
|
|
||||||
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://UI_MACHINE_IP:9000");
|
|
||||||
sparkNet.setListeners(remoteUIRouter, Collections.singletonList(new StatsListener(null)));
|
|
||||||
```
|
|
||||||
To avoid dependency conflicts with Spark, you should use the ```deeplearning4j-ui-model``` dependency to get the StatsListener, *not* the full ```deeplearning4j-ui_2.10``` UI dependency.
|
|
||||||
|
|
||||||
**Note to Scala users**:
|
|
||||||
|
|
||||||
You need to use the above method if you are on a newer Scala version. See the linked example above for the client.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Note: you should replace ```UI_MACHINE_IP``` with the IP address of the machine running the user interface instance.
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="usingui">Using the UI to Tune Your Network</a>
|
|
||||||
|
|
||||||
Here's an excellent [web page by Andrej Karpathy](http://cs231n.github.io/neural-networks-3/#baby) about visualizing neural net training. It is worth reading and understanding that page first.
|
|
||||||
|
|
||||||
Tuning neural networks is often more an art than a science. However, here's some ideas that may be useful:
|
|
||||||
|
|
||||||
**Overview Page - Model Score vs. Iteration Chart**
|
|
||||||
|
|
||||||
The score vs. iteration should (overall) go down over time.
|
|
||||||
|
|
||||||
- If the score increases consistently, your learning rate is likely set too high. Try reducing it until scores become more stable.
|
|
||||||
- Increasing scores can also be indicative of other network issues, such as incorrect data normalization
|
|
||||||
- If the score is flat or decreases very slowly (over a few hundred iterations) (a) your learning rate may be too low, or (b) you might be having difficulties with optimization. In the latter case, if you are using the SGD updater, try a different updater such as Nesterovs (momentum), RMSProp or Adagrad.
|
|
||||||
- Note that data that isn't shuffled (i.e., each minibatch contains only one class, for classification) can result in very rough or abnormal-looking score vs. iteration graphs
|
|
||||||
- Some noise in this line chart is expected (i.e., the line will go up and down within a small range). However, if the score varies very significantly between iterations or runs, this can be a problem
|
|
||||||
- The issues mentioned above (learning rate, normalization, data shuffling) may contribute to this.
|
|
||||||
- Setting the minibatch size to a very small number of examples can also contribute to noisy score vs. iteration graphs, and *might* lead to optimization difficulties
|
|
||||||
|
|
||||||
**Overview Page and Model Page - Using the Update: Parameter Ratio Chart**
|
|
||||||
|
|
||||||
- The ratio of mean magnitude of updates to parameters is provided on both the overview and model pages
|
|
||||||
- "Mean magnitude" = the average of the absolute value of the parameters or updates at the current time step
|
|
||||||
- The most important use of this ratio is in selecting a learning rate. As a rule of thumb: this ratio should be around 1:1000 = 0.001. On the (log<sub>10</sub>) chart, this corresponds to a value of -3 (i.e., 10<sup>-3</sup> = 0.001)
|
|
||||||
    - Note that this is a rough guide only, and may not be appropriate for all networks. It's often a good starting point, however.
|
|
||||||
    - If the ratio diverges significantly from this (for example, > -2 (i.e., 10<sup>-2</sup>=0.01) or < -4 (i.e., 10<sup>-4</sup>=0.0001)), your parameters may be too unstable to learn useful features, or may change too slowly to learn useful features
|
|
||||||
- To change this ratio, adjust your learning rate (or sometimes, parameter initialization). In some networks, you may need to set the learning rate differently for different layers.
|
|
||||||
- Keep an eye out for unusually large spikes in the ratio: this may indicate exploding gradients
|
|
||||||
|
|
||||||
|
|
||||||
**Model Page: Layer Activations (vs. Time) Chart**
|
|
||||||
|
|
||||||
This chart can be used to detect vanishing or exploding activations (due to poor weight initialization, too much regularization, lack of data normalization, or too high a learning rate).
|
|
||||||
|
|
||||||
- This chart should ideally stabilize over time (usually a few hundred iterations)
|
|
||||||
- A good standard deviation for the activations is on the order of 0.5 to 2.0. Significantly outside of this range may indicate one of the problems mentioned above.
|
|
||||||
|
|
||||||
**Model Page: Layer Parameters Histogram**
|
|
||||||
|
|
||||||
The layer parameters histogram is displayed for the most recent iteration only.
|
|
||||||
|
|
||||||
- For weights, these histograms should have an approximately Gaussian (normal) distribution, after some time
|
|
||||||
- For biases, these histograms will generally start at 0, and will usually end up being approximately Gaussian
|
|
||||||
- One exception to this is for LSTM recurrent neural network layers: the biases for one gate (the forget gate) are set to 1.0 by default (though this is configurable), to help in learning dependencies across long time periods. This results in the bias graphs initially having many biases around 0.0, with another set of biases around 1.0
|
|
||||||
- Keep an eye out for parameters that are diverging to +/- infinity: this may be due to too high a learning rate, or insufficient regularization (try adding some L2 regularization to your network).
|
|
||||||
- Keep an eye out for biases that become very large. This can sometimes occur in the output layer for classification, if the distribution of classes is very imbalanced
|
|
||||||
|
|
||||||
**Model Page: Layer Updates Histogram**
|
|
||||||
|
|
||||||
The layer update histogram is displayed for the most recent iteration only.
|
|
||||||
|
|
||||||
- Note that these are the updates - i.e., the gradients *after* applying learning rate, momentum, regularization etc
|
|
||||||
- As with the parameter graphs, these should have an approximately Gaussian (normal) distribution
|
|
||||||
- Keep an eye out for very large values: this can indicate exploding gradients in your network
|
|
||||||
- Exploding gradients are problematic as they can 'mess up' the parameters of your network
|
|
||||||
- In this case, it may indicate a weight initialization, learning rate or input/labels data normalization issue
|
|
||||||
- In the case of recurrent neural networks, adding some [gradient normalization or gradient clipping](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/GradientNormalization.java) may help
|
|
||||||
|
|
||||||
**Model Page: Parameter Learning Rates Chart**
|
|
||||||
|
|
||||||
This chart simply shows the learning rates of the parameters of selected layer, over time.
|
|
||||||
|
|
||||||
If you are not using learning rate schedules, the chart will be flat. If you *are* using learning rate schedules, you can use this chart to track the current value of the learning rate (for each parameter), over time.
|
|
||||||
|
|
||||||
|
|
||||||
## <a name="tsne">TSNE and Word2vec</a>
|
|
||||||
|
|
||||||
We rely on [TSNE](https://lvdmaaten.github.io/tsne/) to reduce the dimensionality of [word feature vectors](./deeplearning4j-nlp-word2vec) and project words into a two or three-dimensional space. Here's some code for using TSNE with Word2Vec:
|
|
||||||
|
|
||||||
```java
|
|
||||||
log.info("Plot TSNE....");
|
|
||||||
BarnesHutTsne tsne = new BarnesHutTsne.Builder()
|
|
||||||
.setMaxIter(1000)
|
|
||||||
.stopLyingIteration(250)
|
|
||||||
.learningRate(500)
|
|
||||||
.useAdaGrad(false)
|
|
||||||
.theta(0.5)
|
|
||||||
.setMomentum(0.5)
|
|
||||||
.normalize(true)
|
|
||||||
.usePca(false)
|
|
||||||
.build();
|
|
||||||
vec.lookupTable().plotVocab(tsne);
|
|
||||||
```
|
|
||||||
|
|
||||||
## <a name="issues">Fixing UI Issue: "No configuration setting" exception</a>
|
|
||||||
|
|
||||||
A possible exception that can occur with the DL4J UI is the following:
|
|
||||||
```
|
|
||||||
com.typesafe.config.ConfigException$Missing: No configuration setting found for key 'play.crypto.provider'
|
|
||||||
at com.typesafe.config.impl.SimpleConfig.findKeyOrNull(SimpleConfig.java:152)
|
|
||||||
at com.typesafe.config.impl.SimpleConfig.findOrNull(SimpleConfig.java:170)
|
|
||||||
...
|
|
||||||
at play.server.Server.forRouter(Server.java:96)
|
|
||||||
at org.deeplearning4j.ui.play.PlayUIServer.runMain(PlayUIServer.java:206)
|
|
||||||
at org.deeplearning4j.ui.api.UIServer.getInstance(UIServer.java:27)
|
|
||||||
```
|
|
||||||
|
|
||||||
This exception is not due to DL4J directly, but is due to a missing application.conf file, required by the Play framework (the library that DL4J's UI is based on). This is originally present in the deeplearning4j-play dependency: however, if an uber-jar (i.e., a JAR file with dependencies) is built (say, via ```mvn package```), it may not be copied over correctly. For example, using the ```maven-assembly-plugin``` has caused this exception for some users.
|
|
||||||
|
|
||||||
The recommended solution (for Maven) is to use the Maven Shade plugin to produce an uber-jar, configured as follows:
|
|
||||||
|
|
||||||
```xml
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.codehaus.mojo</groupId>
|
|
||||||
<artifactId>exec-maven-plugin</artifactId>
|
|
||||||
<version>${exec-maven-plugin.version}</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<goals>
|
|
||||||
<goal>exec</goal>
|
|
||||||
</goals>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
<configuration>
|
|
||||||
<executable>java</executable>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
|
||||||
<artifactId>maven-shade-plugin</artifactId>
|
|
||||||
<version>${maven-shade-plugin.version}</version>
|
|
||||||
<configuration>
|
|
||||||
<shadedArtifactAttached>true</shadedArtifactAttached>
|
|
||||||
<shadedClassifierName>${shadedClassifier}</shadedClassifierName>
|
|
||||||
<createDependencyReducedPom>true</createDependencyReducedPom>
|
|
||||||
<filters>
|
|
||||||
<filter>
|
|
||||||
<artifact>*:*</artifact>
|
|
||||||
<excludes>
|
|
||||||
<!--<exclude>org/datanucleus/**</exclude>-->
|
|
||||||
<exclude>META-INF/*.SF</exclude>
|
|
||||||
<exclude>META-INF/*.DSA</exclude>
|
|
||||||
<exclude>META-INF/*.RSA</exclude>
|
|
||||||
</excludes>
|
|
||||||
</filter>
|
|
||||||
</filters>
|
|
||||||
|
|
||||||
</configuration>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<phase>package</phase>
|
|
||||||
<goals>
|
|
||||||
<goal>shade</goal>
|
|
||||||
</goals>
|
|
||||||
<configuration>
|
|
||||||
<transformers>
|
|
||||||
<transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
|
|
||||||
<resource>reference.conf</resource>
|
|
||||||
</transformer>
|
|
||||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
|
|
||||||
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer" />
|
|
||||||
</transformers>
|
|
||||||
</configuration>
|
|
||||||
</execution>
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
```
|
|
||||||
|
|
||||||
Then, create your uber-jar with ```mvn package``` and run via ```cd target && java -cp dl4j-examples-0.9.1-bin.jar org.deeplearning4j.examples.userInterface.UIExample```. Note the "-bin" suffix for the generated JAR file: this includes all dependencies.
|
|
||||||
|
|
||||||
Note also that this Maven Shade approach is configured for DL4J's examples repository.
|
|
|
@ -1,16 +0,0 @@
|
||||||
# deeplearning4j-scaleout documentation
|
|
||||||
|
|
||||||
Build and serve documentation for deeplearning4j-scaleout with MkDocs (install with `pip install mkdocs`)
|
|
||||||
The documentation source for deeplearning4j-scaleout is in this directory under `doc_sources/`.
|
|
||||||
|
|
||||||
The structure of this project (template files, generating code, mkdocs YAML) is closely aligned
|
|
||||||
with the [Keras documentation](https://keras.io) and heavily inspired by the [Keras docs repository](https://github.com/keras-team/keras/tree/master/docs).
|
|
||||||
|
|
||||||
To generate docs into the `deeplearning4j-scaleout/doc_sources` folder, first `cd docs` then run:
|
|
||||||
|
|
||||||
```shell
|
|
||||||
python generate_docs.py \
|
|
||||||
--project deeplearning4j-scaleout \
|
|
||||||
--code ../deeplearning4j \
|
|
||||||
--out_language en
|
|
||||||
```
|
|
|
@ -1,34 +0,0 @@
|
||||||
{
|
|
||||||
"excludes": [
|
|
||||||
],
|
|
||||||
"indices": [
|
|
||||||
],
|
|
||||||
"pages": [
|
|
||||||
{
|
|
||||||
"page": "intro.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "technicalref.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "howto.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "data-howto.md",
|
|
||||||
"class": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"page": "apiref.md",
|
|
||||||
"class": [
|
|
||||||
"deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java",
|
|
||||||
"deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java",
|
|
||||||
"deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java",
|
|
||||||
"deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
---
|
|
||||||
title: "Deeplearning4j on Spark: API Reference"
|
|
||||||
short_title: API Reference
|
|
||||||
description: "Deeplearning4j on Spark: API Reference"
|
|
||||||
category: Distributed Deep Learning
|
|
||||||
weight: 4
|
|
||||||
---
|
|
||||||
|
|
||||||
# API Reference
|
|
||||||
|
|
||||||
This page provides the API reference for key classes required to do distributed training with DL4J on Spark. Before going through these, make sure you have read the introduction guide for deeplearning4j Spark training [here](deeplearning4j-scaleout-intro).
|
|
||||||
|
|
||||||
{{autogenerated}}
|
|
|
@ -1,490 +0,0 @@
|
||||||
---
|
|
||||||
title: "Deeplearning4j on Spark: How To Build Data Pipelines"
|
|
||||||
short_title: Spark Data Pipelines Guide
|
|
||||||
description: "Deeplearning4j on Spark: How To Build Data Pipelines"
|
|
||||||
category: Distributed Deep Learning
|
|
||||||
weight: 3
|
|
||||||
---
|
|
||||||
|
|
||||||
# Deeplearning4j on Spark: How To Build Data Pipelines
|
|
||||||
|
|
||||||
This page provides some guides on how to create data pipelines for both training and evaluation when using Deeplearning4j on Spark.
|
|
||||||
|
|
||||||
This page assumes some familiarity with Spark (RDDs, master vs. workers, etc) and Deeplearning4j (networks, DataSet etc).
|
|
||||||
|
|
||||||
As with training on a single machine, the final step of a data pipeline should be to produce a DataSet (single features arrays, single label array) or MultiDataSet (one or more feature arrays, one or more label arrays). In the case of DL4J on Spark, the final step of a data pipeline is data in one of the following formats:
|
|
||||||
(a) an ```RDD<DataSet>```/```JavaRDD<DataSet>```
|
|
||||||
(b) an ```RDD<MultiDataSet>```/```JavaRDD<MultiDataSet>```
|
|
||||||
(c) a directory of serialized DataSet/MultiDataSet (minibatch) objects on network storage such as HDFS, S3 or Azure blob storage
|
|
||||||
(d) a directory of minibatches in some other format
|
|
||||||
|
|
||||||
Once data is in one of those four formats, it can be used for training or evaluation.
|
|
||||||
|
|
||||||
**Note:** When training multiple models on a single dataset, it is best practice to preprocess your data once, and save it to network storage such as HDFS.
|
|
||||||
Then, when training the network you can call ```SparkDl4jMultiLayer.fit(String path)``` or ```SparkComputationGraph.fit(String path)``` where ```path``` is the directory where you saved the files.
|
|
||||||
|
|
||||||
|
|
||||||
Spark Data Preparation: How-To Guides
|
|
||||||
* [How to prepare a RDD[DataSet] from CSV data for classification or regression](#csv)
|
|
||||||
* [How to create a Spark data pipeline for training on images](#images)
|
|
||||||
* [How to create a RDD[MultiDataSet] from one or more RDD[List[Writable]]](#multidataset)
|
|
||||||
* [How to save a RDD[DataSet] or RDD[MultiDataSet] to network storage and use it for training](#saveloadrdd)
|
|
||||||
* [How to prepare data on a single machine for use on a cluster: saving DataSets](#singletocluster)
|
|
||||||
* [How to prepare data on a single machine for use on a cluster: map/sequence files](#singletocluster2)
|
|
||||||
* [How to load multiple CSVs (one sequence per file) for RNN data pipelines](#csvseq)
|
|
||||||
* [How to load prepared minibatches in custom format](#customformat)
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
## <a name="csv">How to prepare a RDD[DataSet] from CSV data for classification or regression</a>
|
|
||||||
|
|
||||||
This guide shows how to load data contained in one or more CSV files and produce a ```JavaRDD<DataSet>``` for export, training or evaluation on Spark.
|
|
||||||
|
|
||||||
The process is fairly straightforward. Note that the ```DataVecDataSetFunction``` is very similar to the ```RecordReaderDataSetIterator``` that is often used for single machine training.
|
|
||||||
|
|
||||||
For example, suppose the CSV had the following format - 6 total columns: 5 features followed by an integer class index for classification, and 10 possible classes
|
|
||||||
|
|
||||||
```
|
|
||||||
1.0,3.2,4.5,1.1,6.3,0
|
|
||||||
1.6,2.4,5.9,0.2,2.2,1
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
we could load this data for classification using the following code:
|
|
||||||
```
|
|
||||||
String filePath = "hdfs:///your/path/some_csv_file.csv";
|
|
||||||
JavaSparkContext sc = new JavaSparkContext();
|
|
||||||
JavaRDD<String> rddString = sc.textFile(filePath);
|
|
||||||
RecordReader recordReader = new CSVRecordReader(',');
|
|
||||||
JavaRDD<List<Writable>> rddWritables = rddString.map(new StringToWritablesFunction(recordReader));
|
|
||||||
|
|
||||||
int labelIndex = 5; //Labels: a single integer representing the class index in column number 5
|
|
||||||
int numLabelClasses = 10; //10 classes for the label
|
|
||||||
JavaRDD<DataSet> rddDataSetClassification = rddWritables.map(new DataVecDataSetFunction(labelIndex, numLabelClasses, false));
|
|
||||||
```
|
|
||||||
|
|
||||||
However, if this dataset was for regression instead, with again 6 total columns, 3 feature columns (positions 0, 1 and 2 in the file rows) and 3 label columns (positions 3, 4 and 5) we could load it using the same process as above, but changing the last 3 lines to:
|
|
||||||
|
|
||||||
```
|
|
||||||
int firstLabelColumn = 3; //First column index for label
|
|
||||||
int lastLabelColumn = 5; //Last column index for label
|
|
||||||
JavaRDD<DataSet> rddDataSetRegression = rddWritables.map(new DataVecDataSetFunction(firstLabelColumn, lastLabelColumn, true, null, null));
|
|
||||||
```
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
## <a name="multidataset">How to create a RDD[MultiDataSet] from one or more RDD[List[Writable]]</a>
|
|
||||||
|
|
||||||
RecordReaderMultiDataSetIterator (RRMDSI) is the most common way to create MultiDataSet instances for single-machine training data pipelines.
|
|
||||||
It is possible to use RRMDSI for Spark data pipelines, where data is coming from one or more of ```RDD<List<Writable>>``` (for 'standard' data) or ```RDD<List<List<Writable>>>``` (for sequence data).
|
|
||||||
|
|
||||||
**Case 1: Single ```RDD<List<Writable>>``` to ```RDD<MultiDataSet>```**
|
|
||||||
|
|
||||||
Consider the following *single node* (non-Spark) data pipeline for a CSV classification task.
|
|
||||||
```
|
|
||||||
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
|
|
||||||
recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));
|
|
||||||
|
|
||||||
int batchSize = 32;
|
|
||||||
int labelColumn = 4;
|
|
||||||
int numClasses = 3;
|
|
||||||
MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(batchSize)
|
|
||||||
.addReader("data", recordReader)
|
|
||||||
.addInput("data", 0, labelColumn-1)
|
|
||||||
.addOutputOneHot("data", labelColumn, numClasses)
|
|
||||||
.build();
|
|
||||||
```
|
|
||||||
|
|
||||||
The equivalent Spark data pipeline is as follows:
|
|
||||||
|
|
||||||
```
|
|
||||||
JavaRDD<List<Writable>> rdd = sc.textFile(f.getPath()).map(new StringToWritablesFunction(new CSVRecordReader()));
|
|
||||||
|
|
||||||
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(batchSize)
|
|
||||||
.addReader("data", new SparkSourceDummyReader(0)) //Note the use of the "SparkSourceDummyReader"
|
|
||||||
.addInput("data", 0, labelColumn-1)
|
|
||||||
.addOutputOneHot("data", labelColumn, numClasses)
|
|
||||||
.build();
|
|
||||||
JavaRDD<MultiDataSet> mdsRdd = IteratorUtils.mapRRMDSI(rdd, rrmdsi);
|
|
||||||
```
|
|
||||||
|
|
||||||
For Sequence data (```List<List<Writable>>```) you can use SparkSourceDummySeqReader instead.
|
|
||||||
|
|
||||||
**Case 2: Multiple ```RDD<List<Writable>>``` or ```RDD<List<List<Writable>>>``` to ```RDD<MultiDataSet>```**
|
|
||||||
|
|
||||||
For this case, the process is much the same. However, internally, a join is used.
|
|
||||||
|
|
||||||
```
|
|
||||||
JavaRDD<List<Writable>> rdd1 = ...
|
|
||||||
JavaRDD<List<Writable>> rdd2 = ...
|
|
||||||
|
|
||||||
RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(batchSize)
|
|
||||||
.addReader("rdd1", new SparkSourceDummyReader(0)) //0 = use first rdd in list
|
|
||||||
.addReader("rdd2", new SparkSourceDummyReader(1)) //1 = use second rdd in list
|
|
||||||
.addInput("rdd1", 1, 2) //
|
|
||||||
.addOutput("rdd2", 1, 2)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
List<JavaRDD<List<Writable>>> list = Arrays.asList(rdd1, rdd2);
|
|
||||||
int[] keyIdxs = new int[]{0,0}; //Column 0 in rdd1 and rdd2 is the 'key' used for joining
|
|
||||||
boolean filterMissing = false; //If true: filter out any records that don't have matching keys in all RDDs
|
|
||||||
JavaRDD<MultiDataSet> mdsRdd = IteratorUtils.mapRRMDSI(list, null, keyIdxs, null, filterMissing, rrmdsi);
|
|
||||||
```
|
|
||||||
|
|
||||||
<br><br>
|
|
||||||
|
|
||||||
## <a name="saveloadrdd">How to save a RDD[DataSet] or RDD[MultiDataSet] to network storage and use it for training</a>
|
|
||||||
|
|
||||||
As noted at the start of this page, it is considered a best practice to preprocess and export your data once (i.e., save to network storage such as HDFS and reuse), rather than fitting from an ```RDD<DataSet>``` or ```RDD<MultiDataSet>``` directly in each training job.
|
|
||||||
|
|
||||||
There are a number of reasons for this:
|
|
||||||
* Better performance (avoid redundant loading/calculation): When fitting multiple models from the same dataset, it is faster to preprocess this data once and save to disk rather than preprocessing it again for every single training run.
|
|
||||||
* Minimizing memory and other resources: By exporting and fitting from disk, we only need to keep the DataSets we are currently using (plus a small async prefetch buffer) in memory, rather than also keeping many unused DataSet objects in memory. Exporting results in lower total memory use and hence we can use larger networks, larger minibatch sizes, or allocate fewer resources to our job.
|
|
||||||
* Avoiding recomputation: When an RDD is too large to fit into memory, some parts of it may need to be recomputed before it can be used (depending on the cache settings). When this occurs, Spark will recompute parts of the data pipeline multiple times, costing us both time and memory. A pre-export step avoids this recomputation entirely.
|
|
||||||
|
|
||||||
**Step 1: Saving**
|
|
||||||
|
|
||||||
Saving the DataSet objects once you have an ```RDD<DataSet>``` is quite straightforward:
|
|
||||||
```
|
|
||||||
JavaRDD<DataSet> rddDataSet = ...
|
|
||||||
int minibatchSize = 32; //Minibatch size of the saved DataSet objects
|
|
||||||
String exportPath = "hdfs:///path/to/export/data";
|
|
||||||
JavaRDD<String> paths = rddDataSet.mapPartitionsWithIndex(new BatchAndExportDataSetsFunction(minibatchSize, exportPath), true);
|
|
||||||
```
|
|
||||||
Keep in mind that this is a map function, so no data will be saved until the paths RDD is executed - i.e., you should follow this with an operation such as:
|
|
||||||
```
|
|
||||||
paths.saveAsTextFile("hdfs:///path/to/text/file.txt"); //Specified file will contain paths/URIs of all saved DataSet objects
|
|
||||||
```
|
|
||||||
or
|
|
||||||
```
|
|
||||||
List<String> paths = paths.collect(); //Collection of paths/URIs of all saved DataSet objects
|
|
||||||
```
|
|
||||||
or
|
|
||||||
```
|
|
||||||
paths.foreach(new VoidFunction<String>() {
|
|
||||||
@Override
|
|
||||||
public void call(String path) {
|
|
||||||
//Some operation on each path
|
|
||||||
}
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
Saving an ```RDD<MultiDataSet>``` can be done in the same way using ```BatchAndExportMultiDataSetsFunction``` instead, which takes the same arguments.
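A sketch of this (the export path is illustrative):

```
JavaRDD<MultiDataSet> rddMultiDataSet = ...
int minibatchSize = 32;
String exportPath = "hdfs:///path/to/export/mds";
JavaRDD<String> mdsPaths = rddMultiDataSet.mapPartitionsWithIndex(
        new BatchAndExportMultiDataSetsFunction(minibatchSize, exportPath), true);
```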

**Step 2: Loading and Fitting**

The exported data can be used in a few ways.
First, it can be used to fit a network directly:

```
String exportPath = "hdfs:///path/to/export/data";
SparkDl4jMultiLayer net = ...
net.fit(exportPath);        //Loads the serialized DataSet objects found in the 'exportPath' directory
```

Similarly, we can use ```SparkComputationGraph.fitMultiDataSet(String path)``` if we saved an ```RDD<MultiDataSet>``` instead.

Alternatively, we can load up the paths in a few different ways, depending on how we saved them:

```
JavaSparkContext sc = new JavaSparkContext();

//Option 1 - if we used saveAsTextFile:
String saveTo = "hdfs:///path/to/text/file.txt";
paths.saveAsTextFile(saveTo);                                           //Save
JavaRDD<String> loadedPaths = sc.textFile(saveTo);                      //Load

//Option 2 - if we collected the paths to a List<String>:
List<String> pathList = paths.collect();                                //Collect
JavaRDD<String> loadedPaths = sc.parallelize(pathList);                 //Parallelize

//Option 3 - if we want to list the directory contents directly:
String exportPath = "hdfs:///path/to/export/data";
JavaRDD<String> loadedPaths = SparkUtils.listPaths(sc, exportPath);     //List paths using org.deeplearning4j.spark.util.SparkUtils
```

Then we can execute training on these paths by using methods such as ```SparkDl4jMultiLayer.fitPaths(JavaRDD<String>)```.
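
For example (a minimal sketch, reusing ```net``` and ```loadedPaths``` from the snippets above):

```
net.fitPaths(loadedPaths);  //Loads the DataSet objects at the given paths and fits the network on them
```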

<br><br>

## <a name="singletocluster">How to prepare data on a single machine for use on a cluster: saving DataSets</a>

Another possible workflow is to start with the data pipeline on a single machine, and export the DataSet or MultiDataSet objects for use on the cluster.
This workflow clearly isn't as scalable as preparing data on a cluster (you are using just one machine to prepare data) but it can be an easy option in some cases, especially when you have an existing data pipeline.

This section assumes you have an existing ```DataSetIterator``` or ```MultiDataSetIterator``` used for single-machine training. There are many different ways to create one, which is outside of the scope of this guide.
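
For instance (purely as an illustration - any iterator will do here), a built-in iterator such as ```IrisDataSetIterator``` could stand in:

```
DataSetIterator iter = new IrisDataSetIterator(32, 150);    //Batch size 32, 150 examples in total
```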

**Step 1: Save the DataSets or MultiDataSets**

Saving the contents of a DataSetIterator to a local directory can be done using the following code:

```
DataSetIterator iter = ...
File rootDir = new File("/saving/directory/");
int count = 0;
while(iter.hasNext()){
    DataSet ds = iter.next();
    File outFile = new File(rootDir, "dataset_" + (count++) + ".bin");
    ds.save(outFile);
}
```

Note that for the purposes of Spark, the exact file names don't matter.
The process for saving MultiDataSets is almost identical.
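
A minimal sketch of the MultiDataSet case (assuming ```MultiDataSet.save(File)```, which may declare ```IOException```; error handling is omitted here):

```
MultiDataSetIterator iter = ...
File rootDir = new File("/saving/directory/");
int count = 0;
while(iter.hasNext()){
    MultiDataSet mds = iter.next();
    mds.save(new File(rootDir, "multidataset_" + (count++) + ".bin"));
}
```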

As an aside: you can read these saved DataSet objects on a single machine (for non-Spark training) using [FileDataSetIterator](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/file/FileDataSetIterator.java).
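
For example (a sketch; this assumes a ```FileDataSetIterator(File, int batchSize)``` constructor - check the class for the exact overloads available):

```
//Assumed constructor: FileDataSetIterator(File rootDir, int batchSize)
DataSetIterator localIter = new FileDataSetIterator(new File("/saving/directory/"), 32);
while(localIter.hasNext()){
    DataSet ds = localIter.next();
    //Single-machine training/evaluation on ds here
}
```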

An alternative approach is to save directly to the cluster using output streams, to (for example) HDFS. This can only be done if the machine running the code is properly configured with the required libraries and access rights. For example, to save the DataSets directly to HDFS you could use:

```
JavaSparkContext sc = new JavaSparkContext();
FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
String outputDir = "hdfs:///my/output/location/";

DataSetIterator iter = ...
int count = 0;
while(iter.hasNext()){
    DataSet ds = iter.next();
    String filePath = outputDir + "dataset_" + (count++) + ".bin";
    try (OutputStream os = new BufferedOutputStream(fileSystem.create(new Path(filePath)))) {
        ds.save(os);
    }
}
```

**Step 2: Load and Train on a Cluster**

The saved DataSet objects can then be copied to the cluster or network file storage (for example, using Hadoop FS utilities on a Hadoop cluster), and used as follows:

```
String dir = "hdfs:///data/copied/here";
SparkDl4jMultiLayer net = ...
net.fit(dir);   //Loads the serialized DataSet objects found in the 'dir' directory
```

Alternatively (and equivalently), we can list the paths as an RDD using:

```
String dir = "hdfs:///data/copied/here";
JavaRDD<String> paths = SparkUtils.listPaths(sc, dir);  //List paths using org.deeplearning4j.spark.util.SparkUtils
```
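
and then fit on those paths, as shown earlier:

```
net.fitPaths(paths);
```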

<br><br>

## <a name="singletocluster2">How to prepare data on a single machine for use on a cluster: map/sequence files</a>

An alternative approach is to use Hadoop MapFiles and SequenceFiles, which are efficient binary storage formats.
This can be used to convert the output of any DataVec ```RecordReader``` or ```SequenceRecordReader``` (including a custom record reader) to a format usable with Spark.
MapFileRecordWriter and MapFileSequenceRecordWriter require the following dependencies:

```
<dependency>
    <groupId>org.datavec</groupId>
    <artifactId>datavec-hadoop</artifactId>
    <version>${datavec.version}</version>
</dependency>
<dependency>
    <groupId>org.apache.hadoop</groupId>
    <artifactId>hadoop-common</artifactId>
    <version>${hadoop.version}</version>
    <!-- Optional exclusion for log4j in case you are using other logging frameworks -->
    <!--
    <exclusions>
        <exclusion>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
        </exclusion>
        <exclusion>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
        </exclusion>
    </exclusions>
    -->
</dependency>
```

**Step 1: Create a MapFile Locally**

In the following example, a CSVRecordReader will be used, but any other RecordReader could be used in its place:

```
File csvFile = new File("/path/to/file.csv");
RecordReader recordReader = new CSVRecordReader();
recordReader.initialize(new FileSplit(csvFile));

//Create map file writer
String outPath = "/map/file/root/dir";
MapFileRecordWriter writer = new MapFileRecordWriter(new File(outPath));

//Convert to MapFile binary format:
RecordReaderConverter.convert(recordReader, writer);
```

The process for using a ```SequenceRecordReader``` combined with a ```MapFileSequenceRecordWriter``` is virtually the same.
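
A minimal sketch of the sequence case (this assumes a ```MapFileSequenceRecordWriter(File)``` constructor and a sequence overload of ```RecordReaderConverter.convert``` that mirror the non-sequence versions shown above):

```
SequenceRecordReader seqReader = new CSVSequenceRecordReader(0, ",");       //One sequence per CSV file
seqReader.initialize(new FileSplit(new File("/path/to/sequence/csvs/")));

MapFileSequenceRecordWriter seqWriter = new MapFileSequenceRecordWriter(new File("/map/file/seq/dir"));
RecordReaderConverter.convert(seqReader, seqWriter);                        //Assumed sequence overload
```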

Note also that ```MapFileRecordWriter``` and ```MapFileSequenceRecordWriter``` both support splitting - i.e., creating multiple smaller map files instead of creating one single (potentially multi-GB) map file. Using splitting is recommended when saving data in this manner for use with Spark.

**Step 2: Copy to HDFS or other network file storage**

The exact process is beyond the scope of this guide. However, it should be sufficient to simply copy the directory ("/map/file/root/dir" in the example above) to a location on HDFS.

**Step 3: Read and Convert to ```RDD<DataSet>``` for Training**

We can load the data for training using the following:

```
JavaSparkContext sc = new JavaSparkContext();
String pathOnHDFS = "hdfs:///map/file/directory";
JavaRDD<List<Writable>> rdd = SparkStorageUtils.restoreMapFile(pathOnHDFS, sc);     //import: org.datavec.spark.storage.SparkStorageUtils

//Note at this point: it's the same as the latter part of the CSV how-to guide
int labelIndex = 5;         //Labels: a single integer representing the class index in column number 5
int numLabelClasses = 10;   //10 classes for the label
JavaRDD<DataSet> rddDataSetClassification = rdd.map(new DataVecDataSetFunction(labelIndex, numLabelClasses, false));
```

<br><br>

## <a name="csvseq">How to load multiple CSVs (one sequence per file) for RNN data pipelines</a>

This guide shows how to load CSV files for training an RNN.
The assumption is that the dataset comprises multiple CSV files, where:

* each CSV file represents one sequence
* each row/line of the CSV contains the values for one time step (one or more columns/values, same number of values in all rows for all files)
* each CSV may contain a different number of lines than other CSVs (i.e., variable-length sequences are OK)
* header lines either aren't present in any files, or are present in all files

A data pipeline can be created using the following process:

```
String directoryWithCsvFiles = "hdfs:///path/to/directory";
JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(directoryWithCsvFiles);

int numHeaderLinesEachFile = 0;     //No header lines
String delimiter = ",";             //Comma delimited files
SequenceRecordReader seqRR = new CSVSequenceRecordReader(numHeaderLinesEachFile, delimiter);

JavaRDD<List<List<Writable>>> sequencesRdd = origData.map(new SequenceRecordReaderFunction(seqRR));

//Similar to the non-sequence CSV guide using DataVecDataSetFunction. Assuming classification here:
int labelIndex = 5;     //Index of the label column. Occurs at position/column 5
int numClasses = 10;    //Number of classes for classification
JavaRDD<DataSet> dataSetRdd = sequencesRdd.map(new DataVecSequenceDataSetFunction(labelIndex, numClasses, false));
```
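
Following the best practice from the start of this page, this ```RDD<DataSet>``` can then be exported once and reused across training runs - for example (a sketch, reusing the export approach shown earlier):

```
int minibatchSize = 32;
String exportPath = "hdfs:///path/to/export/sequence/data";
JavaRDD<String> exportedPaths = dataSetRdd.mapPartitionsWithIndex(new BatchAndExportDataSetsFunction(minibatchSize, exportPath), true);
exportedPaths.saveAsTextFile("hdfs:///path/to/exported/paths.txt");     //Trigger the export
```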

<br><br>

## <a name="images">How to create a Spark data pipeline for training on images</a>

This guide shows how to create an ```RDD<DataSet>``` for image classification, starting from images stored either locally, or on a network file system such as HDFS.

The approach used here (added in 1.0.0-beta3) is to first preprocess the images into batches of files - [FileBatch](https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-common/src/main/java/org/nd4j/api/loader/FileBatch.java) objects.
The motivation for this approach is simple: the original image files typically use efficient compression (JPEG, for example), which is much more space (and network) efficient than a bitmap (int8 or 32-bit floating point) representation. However, on a cluster we want to minimize disk reads due to latency issues with remote storage - one file read/transfer is going to be faster than ```minibatchSize``` remote file reads.

The [TinyImageNet example](https://github.com/eclipse/deeplearning4j-examples/tree/master/dl4j-spark-examples/dl4j-spark/src/main/java/org/deeplearning4j/tinyimagenet) also shows how this can be done.

Note that one limitation of the implementation is that the set of classes (i.e., the class/category labels when doing classification) needs to be known, provided, or collected manually. This differs from using ImageRecordReader for classification on a single machine, which can automatically infer the set of class labels.

First, assume the images are in subdirectories based on their class labels. For example, if there are two classes, "cat" and "dog", the directory structure would look like:

```
rootDir/cat/img0.jpg
rootDir/cat/img1.jpg
...
rootDir/dog/img0.jpg
rootDir/dog/img1.jpg
...
```

(Note that the file names don't matter in this example - however, the parent directory names are the class labels.)

**Step 1 (option 1 of 2): Preprocess Locally**

Local preprocessing can be done as follows:

```
String sourceDirectory = "/home/user/my_images";            //Where your data is located
String destinationDirectory = "/home/user/preprocessed";    //Where the preprocessed data should be written
int batchSize = 32;                                         //Number of examples (images) in each FileBatch object
SparkDataUtils.createFileBatchesLocal(new File(sourceDirectory), NativeImageLoader.ALLOWED_FORMATS, true, new File(destinationDirectory), batchSize);
```

The full import for SparkDataUtils is ```org.deeplearning4j.spark.util.SparkDataUtils```.

After preprocessing has been completed, the directory can be copied to the cluster for use in training (Step 2).

**Step 1 (option 2 of 2): Preprocess using Spark**

Alternatively, if the original images are on remote file storage (such as HDFS), we can use the following:

```
String sourceDirectory = "hdfs:///data/my_images";          //Where your data is located
String destinationDirectory = "hdfs:///data/preprocessed";  //Where the preprocessed data should be written
int batchSize = 32;                                         //Number of examples (images) in each FileBatch object
SparkDataUtils.createFileBatchesSpark(sourceDirectory, destinationDirectory, batchSize, sparkContext);
```

**Step 2: Training**

The data pipeline for image classification can be constructed as follows. This code is taken from the [TinyImageNet example](https://github.com/eclipse/deeplearning4j-examples/blob/master/dl4j-spark-examples/dl4j-spark/src/main/java/org/deeplearning4j/tinyimagenet/TrainSpark.java):

```
//Create data loader
int imageHeightWidth = 64;      //64x64 pixel input to network
int imageChannels = 3;          //RGB
int minibatch = 32;             //Minibatch size for fitting (typically the number of images per FileBatch)
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(imageHeightWidth, imageHeightWidth, imageChannels, labelMaker);
rr.setLabels(Arrays.asList("cat", "dog"));
int numClasses = 2;
RecordReaderFileBatchLoader loader = new RecordReaderFileBatchLoader(rr, minibatch, 1, numClasses);
loader.setPreProcessor(new ImagePreProcessingScaler());     //Scale 0-255 valued pixels to 0-1 range

//Fit the network
String trainDataPath = "hdfs:///data/preprocessed";         //Where the preprocessed data is located
JavaRDD<String> pathsTrain = SparkUtils.listPaths(sc, trainDataPath);
for (int i = 0; i < numEpochs; i++) {
    sparkNet.fitPaths(pathsTrain, loader);
}
```

And that's it.

Note: for other label generation cases (such as labels provided from the filename instead of parent directory), or for tasks such as semantic segmentation, you can substitute a different PathLabelGenerator instead of the default. For example, if the label should come from the file name, you can use ```PatternPathLabelGenerator``` instead.
Say the images are named in the format "cat_img1234.jpg", "dog_2309.png", etc. We can use the following process:

```
PathLabelGenerator labelGenerator = new PatternPathLabelGenerator("_", 0);  //Split on the "_" character, and take the first value
ImageRecordReader imageRecordReader = new ImageRecordReader(imageHW, imageHW, imageChannels, labelGenerator);
```

Note that PathLabelGenerator returns a Writable object, so for tasks like image segmentation, you can return an INDArray using the NDArrayWritable class in a custom PathLabelGenerator.
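
A minimal sketch of such a custom generator (assuming the three-method ```PathLabelGenerator``` interface; the mask-loading step is hypothetical and left as a placeholder):

```
public class SegmentationLabelGenerator implements PathLabelGenerator {
    @Override
    public Writable getLabelForPath(String path) {
        INDArray mask = ...;                        //Hypothetical: load the per-pixel label array for this image path
        return new NDArrayWritable(mask);
    }

    @Override
    public Writable getLabelForPath(URI uri) {
        return getLabelForPath(new File(uri).getAbsolutePath());
    }

    @Override
    public boolean inferLabelClasses() {
        return false;                               //Labels are arrays, not inferrable class names
    }
}
```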

<br><br>

## <a name="customformat">How to load prepared minibatches in custom format</a>

DL4J Spark training supports the ability to load data serialized in a custom format. The assumption is that each file on the remote/network storage represents a single minibatch of data in some readable format.

Note that this approach is typically neither required nor recommended for most users, but is provided as an additional option for advanced users or those with pre-prepared data in a custom format, or a format that is not natively supported by DL4J.
When files represent a single record/example (instead of a minibatch) in a custom format, a custom RecordReader could be used instead.

The interfaces of note are:

* [DataSetLoader](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/loader/DataSetLoader.java)
* [MultiDataSetLoader](https://github.com/eclipse/deeplearning4j/blob/master/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/api/loader/MultiDataSetLoader.java)

Both extend the single-method [Loader](https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-common/src/main/java/org/nd4j/api/loader/Loader.java) interface.

Suppose an HDFS directory contains a number of files, each being a minibatch in some custom format.
These can be loaded using the following process:

```
JavaSparkContext sc = new JavaSparkContext();
String dataDirectory = "hdfs:///path/with/data";
JavaRDD<String> loadedPaths = SparkUtils.listPaths(sc, dataDirectory);  //List paths using org.deeplearning4j.spark.util.SparkUtils

SparkDl4jMultiLayer net = ...
Loader<DataSet> myCustomLoader = new MyCustomLoader();
net.fitPaths(loadedPaths, myCustomLoader);
```

Where the custom loader class looks something like:

```
public class MyCustomLoader implements DataSetLoader {
    @Override
    public DataSet load(Source source) throws IOException {
        InputStream inputStream = source.getInputStream();
        //Load your custom data format here:
        INDArray features = ...;
        INDArray labels = ...;
        return new DataSet(features, labels);
    }
}
```
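
As a concrete (but hypothetical) illustration: if each file were simply a DataSet written with ```DataSet.save(OutputStream)```, the loader body could be as short as the following. In practice DL4J can fit on such files directly, so this only illustrates the interface.

```
public class SavedDataSetLoader implements DataSetLoader {
    @Override
    public DataSet load(Source source) throws IOException {
        try (InputStream inputStream = source.getInputStream()) {
            DataSet ds = new DataSet();
            ds.load(inputStream);   //DataSet provides binary deserialization from a stream
            return ds;
        }
    }
}
```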