diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java index 0a5ddddb8..e9b78e390 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java @@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch { public NDArrayRecordBatch(@NonNull List arrays){ Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty"); this.arrays = arrays; + this.size = arrays.get(0).size(0); //Check that dimension 0 matches: if(arrays.size() > 1){ diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java index d2518eeb6..0bfddeb32 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java @@ -16,6 +16,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.datavec.image.data.Image; import org.datavec.image.transform.ImageTransform; @@ -35,10 +36,9 @@ import java.util.Random; /** * Created by nyghtowl on 12/17/15. */ +@Slf4j public abstract class BaseImageLoader implements Serializable { - protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class); - public enum MultiPageMode { MINIBATCH, FIRST //, CHANNELS, } @@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable { public abstract INDArray asRowVector(InputStream inputStream) throws IOException; + /** As per {@link #asMatrix(File, boolean)} but in NCHW/channels_first format */ public abstract INDArray asMatrix(File f) throws IOException; + /** + * Load an image from a file to an INDArray + * @param f File to load the image from + * @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return + * in NHWC/channels_last [1, height, width, channels] format + * @return Image file as an INDArray + */ + public abstract INDArray asMatrix(File f, boolean nchw) throws IOException; + public abstract INDArray asMatrix(InputStream inputStream) throws IOException; + /** + * Load an image file from an input stream to an INDArray + * @param inputStream Input stream to load the image from + * @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return + * in NHWC/channels_last [1, height, width, channels] format + * @return Image file stream as an INDArray + */ + public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException; + /** As per {@link #asMatrix(File)} but as an {@link Image}*/ public abstract Image asImageMatrix(File f) throws IOException; + /** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/ public abstract Image asImageMatrix(File f, boolean nchw) throws IOException; + /** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/ public abstract Image asImageMatrix(InputStream inputStream) throws IOException; + /** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/ public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws
IOException; public static void downloadAndUntar(Map urlMap, File fullDir) { diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java index 3d390c698..e513ebed3 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java @@ -16,6 +16,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.bytedeco.javacv.OpenCVFrameConverter; @@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; * There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example * https://github.com/szagoruyko/cifar.torch */ +@Slf4j public class CifarLoader extends NativeImageLoader implements Serializable { public static final int NUM_TRAIN_IMAGES = 50000; public static final int NUM_TEST_IMAGES = 10000; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java index d246c65ad..9c2c61d57 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java @@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader { * @throws IOException */ public INDArray asMatrix(File f) throws IOException { - return NDArrayUtil.toNDArray(fromFile(f)); + return asMatrix(f, true); + } + + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + return asMatrix(is, nchw); + } } /** @@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader { * @return the input stream to convert */ public INDArray asMatrix(InputStream inputStream) throws IOException { - if (channels == 3) - return toBgr(inputStream); - try { - BufferedImage image = ImageIO.read(inputStream); - return asMatrix(image); - } catch (IOException e) { - throw new IOException("Unable to load image", e); + return asMatrix(inputStream, true); + } + + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + INDArray ret; + if (channels == 3) { + ret = toBgr(inputStream); + } else { + try { + BufferedImage image = ImageIO.read(inputStream); + ret = asMatrix(image); + } catch (IOException e) { + throw new IOException("Unable to load image", e); + } } + if(ret.rank() == 3){ + ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2)); + } + if(!nchw) + ret = ret.permute(0,2,3,1); //NCHW to NHWC + return ret; } @Override public org.datavec.image.data.Image asImageMatrix(File f) throws IOException { + return asImageMatrix(f, true); + } + + @Override + public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asImageMatrix(bis); + return asImageMatrix(bis, nchw); } } @Override public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException { - if (channels == 3) - return toBgrImage(inputStream); - try { - BufferedImage image = 
ImageIO.read(inputStream); - INDArray asMatrix = asMatrix(image); - return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth()); - } catch (IOException e) { - throw new IOException("Unable to load image", e); + return asImageMatrix(inputStream, true); + } + + @Override + public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + org.datavec.image.data.Image ret; + if (channels == 3) { + ret = toBgrImage(inputStream); + } else { + try { + BufferedImage image = ImageIO.read(inputStream); + INDArray asMatrix = asMatrix(image); + ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth()); + } catch (IOException e) { + throw new IOException("Unable to load image", e); + } } + if(ret.getImage().rank() == 3){ + INDArray a = ret.getImage(); + ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2))); + } + if(!nchw) + ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC + return ret; } /** diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java index d28c73318..b71c53e42 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java @@ -17,6 +17,7 @@ package org.datavec.image.loader; +import lombok.extern.slf4j.Slf4j; import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PatternPathLabelGenerator; @@ -48,6 +49,7 @@ import java.util.Random; * most images are in color, although a few are grayscale * */ +@Slf4j public class LFWLoader extends BaseImageLoader implements Serializable { public final static int NUM_IMAGES = 13233; @@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable { throw new UnsupportedOperationException(); } + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public INDArray asMatrix(InputStream inputStream) throws IOException { throw new UnsupportedOperationException(); } + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public Image asImageMatrix(File f) throws IOException { throw new UnsupportedOperationException(); } + @Override + public Image asImageMatrix(File f, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + @Override public Image asImageMatrix(InputStream inputStream) throws IOException { throw new UnsupportedOperationException(); } + @Override + public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + throw new UnsupportedOperationException(); + } + } diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 88bc161f2..ae9e2a322 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -248,17 
+248,27 @@ public class NativeImageLoader extends BaseImageLoader { @Override public INDArray asMatrix(File f) throws IOException { + return asMatrix(f, true); + } + + @Override + public INDArray asMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asMatrix(bis); + return asMatrix(bis, nchw); } } @Override public INDArray asMatrix(InputStream is) throws IOException { - Mat mat = streamToMat(is); + return asMatrix(is, true); + } + + @Override + public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + Mat mat = streamToMat(inputStream); INDArray a; if (this.multiPageMode != null) { - a = asMatrix(mat.data(), mat.cols()); + a = asMatrix(mat.data(), mat.cols()); }else{ Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { @@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader { a = asMatrix(image); image.deallocate(); } - return a; + if(nchw) { + return a; + } else { + return a.permute(0, 2, 3, 1); //NCHW to NHWC + } } /** @@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader { } public Image asImageMatrix(String filename) throws IOException { - return asImageMatrix(filename); + return asImageMatrix(new File(filename)); } @Override public Image asImageMatrix(File f) throws IOException { + return asImageMatrix(f, true); + } + + @Override + public Image asImageMatrix(File f, boolean nchw) throws IOException { try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asImageMatrix(bis); + return asImageMatrix(bis, nchw); } } @Override public Image asImageMatrix(InputStream is) throws IOException { - Mat mat = streamToMat(is); + return asImageMatrix(is, true); + } + + @Override + public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { + Mat mat = streamToMat(inputStream); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { PIX pix = pixReadMem(mat.data(), mat.cols()); @@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader { pixDestroy(pix); } INDArray a = asMatrix(image); + if(!nchw) + a = a.permute(0,2,3,1); //NCHW to NHWC Image i = new Image(a, image.channels(), image.rows(), image.cols()); image.deallocate(); diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index fb780ea74..d5400ee8e 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { protected int patternPosition = 0; @Getter @Setter protected boolean logLabelCountOnInit = true; + @Getter @Setter + protected boolean nchw_channels_first = true; public final static String HEIGHT = NAME_SPACE + ".height"; public final static String WIDTH = NAME_SPACE + ".width"; @@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) { + this(height, 
width, channels, true, labelGenerator, labelMultiGenerator, imageTransform); + } + + protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator, + PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) { this.height = height; this.width = width; this.channels = channels; @@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { this.labelMultiGenerator = labelMultiGenerator; this.imageTransform = imageTransform; this.appendLabel = (labelGenerator != null || labelMultiGenerator != null); + this.nchw_channels_first = nchw_channels_first; } protected boolean containsFormat(String format) { @@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { return next(); try { invokeListeners(image); - INDArray row = imageLoader.asMatrix(image); - Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE); - ret = RecordConverter.toRecord(row); + INDArray array = imageLoader.asMatrix(image); + if(!nchw_channels_first){ + array = array.permute(0,2,3,1); //NCHW to NHWC + } + + Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE); + ret = RecordConverter.toRecord(array); if (appendLabel || writeLabel){ if(labelMultiGenerator != null){ ret.addAll(labelMultiGenerator.getLabels(image.getPath())); @@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { @Override public List> next(int num) { - Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num); + Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num); if (imageLoader == null) { imageLoader = new NativeImageLoader(height, width, channels, imageTransform); @@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { throw new RuntimeException(e); } } + if(!nchw_channels_first){ + features = features.permute(0,2,3,1); //NCHW to NHWC + } Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE); @@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { if (imageLoader == null) { imageLoader = new NativeImageLoader(height, width, channels, imageTransform); } - INDArray row = imageLoader.asMatrix(dataInputStream); - List ret = RecordConverter.toRecord(row); + INDArray array = imageLoader.asMatrix(dataInputStream); + if(!nchw_channels_first) + array = array.permute(0,2,3,1); + List ret = RecordConverter.toRecord(array); if (appendLabel) ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath())))); return ret; diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java index be7a6d8d9..f8e292c26 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/ImageRecordReader.java @@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform; public class ImageRecordReader extends BaseImageRecordReader { - /** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */ + /** Loads images with height = 28, width = 28, and channels = 1, appending no labels. 
+ * Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/ public ImageRecordReader() { super(); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] + */ public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) { super(height, width, channels, labelGenerator); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] + */ public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) { super(height, width, channels, labelGenerator); } - /** Loads images with given height, width, and channels, appending no labels. */ + /** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */ public ImageRecordReader(long height, long width, long channels) { super(height, width, channels, (PathLabelGenerator) null); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending no labels - in specified format
+ * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]
+ * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]
+ */ + public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) { + super(height, width, channels, nchw_channels_first, null, null, null); + } + + /** Loads images with given height, width, and channels, appending labels returned by the generator. + * Output format is NCHW (channels first) - [numExamples, channels, height, width] */ public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator, ImageTransform imageTransform) { super(height, width, channels, labelGenerator, imageTransform); } - /** Loads images with given height, width, and channels, appending no labels. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator.
+ * If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]
+ * If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]
+ */ + public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator, + ImageTransform imageTransform) { + super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform); + } + + /** Loads images with given height, width, and channels, appending no labels. + * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) { super(height, width, channels, null, imageTransform); } - /** Loads images with given height, width, and channels, appending labels returned by the generator. */ + /** Loads images with given height, width, and channels, appending labels returned by the generator + * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) { super(height, width, 1, labelGenerator); } - /** Loads images with given height, width, and channels = 1, appending no labels. */ + /** Loads images with given height, width, and channels = 1, appending no labels. + * Output format is NCHW (channels first) - [numExamples, channels, height, width]*/ public ImageRecordReader(long height, long width) { super(height, width, 1, null, null); } - - - } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java index 1683980f0..a82f12409 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java @@ -16,10 +16,16 @@ package org.datavec.image.loader; +import org.datavec.image.data.Image; import org.junit.Test; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.ndarray.INDArray; import java.awt.image.BufferedImage; +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; import java.util.Random; import static org.junit.Assert.assertEquals; @@ -208,4 +214,57 @@ public class TestImageLoader { private BufferedImage makeRandomBufferedImage(boolean alpha) { return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100); } + + + @Test + public void testNCHW_NHWC() throws Exception { + File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg"); + + ImageLoader il = new ImageLoader(32, 32, 3); + + //asMatrix(File, boolean) + INDArray a_nchw = il.asMatrix(f); + INDArray a_nchw2 = il.asMatrix(f, true); + INDArray a_nhwc = il.asMatrix(f, false); + + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw = il.asMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw2 = il.asMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nhwc = il.asMatrix(is, false); + } + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asImageMatrix(File, boolean) + Image i_nchw = il.asImageMatrix(f); + Image i_nchw2 = il.asImageMatrix(f, true); + Image i_nhwc = il.asImageMatrix(f, false); + + assertEquals(i_nchw.getImage(), 
i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + + + //asImageMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw = il.asImageMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw2 = il.asImageMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nhwc = il.asImageMatrix(is, false); + } + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + } } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 6e7705569..68e93107c 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer; import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; +import org.datavec.image.data.Image; import org.datavec.image.data.ImageWritable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; import java.awt.image.BufferedImage; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.io.IOException; +import java.io.*; import java.lang.reflect.Field; import java.util.Random; @@ -604,4 +603,56 @@ public class TestNativeImageLoader { } } + @Test + public void testNCHW_NHWC() throws Exception { + File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg"); + + NativeImageLoader il = new NativeImageLoader(32, 32, 3); + + //asMatrix(File, boolean) + INDArray a_nchw = il.asMatrix(f); + INDArray a_nchw2 = il.asMatrix(f, true); + INDArray a_nhwc = il.asMatrix(f, false); + + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw = il.asMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nchw2 = il.asMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + a_nhwc = il.asMatrix(is, false); + } + assertEquals(a_nchw, a_nchw2); + assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); + + + //asImageMatrix(File, boolean) + Image i_nchw = il.asImageMatrix(f); + Image i_nchw2 = il.asImageMatrix(f, true); + Image i_nhwc = il.asImageMatrix(f, false); + + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + + + //asImageMatrix(InputStream, boolean) + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw = il.asImageMatrix(is); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nchw2 = il.asImageMatrix(is, true); + } + try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ + i_nhwc = 
il.asImageMatrix(is, false); + } + assertEquals(i_nchw.getImage(), i_nchw2.getImage()); + assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW + } + } diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java index 80cb9b0af..26cd83f06 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java @@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.io.ClassPathResource; -import java.io.File; -import java.io.IOException; +import java.io.*; import java.net.URI; import java.util.ArrayList; import java.util.Arrays; @@ -467,5 +467,87 @@ public class TestImageRecordReader { return count; } } + + + + @Test + public void testNCHW_NHWC() throws Exception { + //Idea: NCHW and NHWC readers over the same files (same random seed) should produce the same data, differing only in layout + File f0 = testDir.newFolder(); + new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0); + + FileSplit fs0 = new FileSplit(f0, new Random(12345)); + FileSplit fs1 = new FileSplit(f0, new Random(12345)); + assertEquals(6, fs0.locations().length); + assertEquals(6, fs1.locations().length); + + ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true); + nchw.initialize(fs0); + + ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false); + nhwc.initialize(fs1); + + while(nchw.hasNext()){ + assertTrue(nhwc.hasNext()); + + List l_nchw = nchw.next(); + List l_nhwc = nhwc.next(); + + INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get(); + INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get(); + + assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape()); + assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape()); + + INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW + assertEquals(a_nchw, permuted); + } + + + //Test batch: + nchw.reset(); + nhwc.reset(); + + int batchCount = 0; + while(nchw.hasNext()){ + assertTrue(nhwc.hasNext()); + batchCount++; + + List> l_nchw = nchw.next(3); + List> l_nhwc = nhwc.next(3); + assertEquals(3, l_nchw.size()); + assertEquals(3, l_nhwc.size()); + + NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw; + NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc; + + INDArray a_nchw = b_nchw.getArrays().get(0); + INDArray a_nhwc = b_nhwc.getArrays().get(0); + + assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape()); + assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape()); + + INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW + assertEquals(a_nchw, permuted); + } + assertEquals(2, batchCount); + + + //Test record(URI, DataInputStream) + + URI u = fs0.locations()[0]; + + try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) { + List l = nchw.record(u, dis); + INDArray arr = ((NDArrayWritable)l.get(0)).get(); + assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape()); + } + + try(DataInputStream dis = new DataInputStream(new
BufferedInputStream(new FileInputStream(new File(u))))) { + List l = nhwc.record(u, dis); + INDArray arr = ((NDArrayWritable)l.get(0)).get(); + assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape()); + } + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java index b52b7cb49..3ea9e07f3 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/RandomTests.java @@ -8,12 +8,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.common.resources.Resources; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.nio.file.Files; import java.util.concurrent.CountDownLatch; @Ignore diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java index 5d540baa7..76d14d47d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -18,11 +18,9 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.*; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; +import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.CnnLossLayer; @@ -35,6 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.util.ConvolutionUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -49,6 +48,7 @@ import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; @RunWith(Parameterized.class) public class ConvDataFormatTests extends BaseDL4JTest { @@ -971,4 +971,58 @@ public class ConvDataFormatTests extends BaseDL4JTest { return null; } } + + + @Test + public void testWrongFormatIn(){ + + for(CNN2DFormat df : CNN2DFormat.values()){ + + + for(int i=0; i<4; i++ ){ + + NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() + .list(); + switch (i){ + case 0: + b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build()); + break; + case 1: + b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build()); + break; + case 2: + 
b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build()); + break; + case 3: + b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build()); + break; + } + + MultiLayerNetwork net = new MultiLayerNetwork(b.build()); + net.init(); + + INDArray in; + INDArray wrongFormatIn; + if(df == CNN2DFormat.NCHW){ + in = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12); + wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3); + } else { + in = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3); + wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12); + } + + net.output(in); + + try { + net.output(wrongFormatIn); + } catch (DL4JInvalidInputException e){ +// e.printStackTrace(); + String msg = e.getMessage(); + assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG)); + } + } + } + + + } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java index 63a69b82c..5abd5a253 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java @@ -21,9 +21,9 @@ import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue; -import org.deeplearning4j.core.util.ThreadUtils; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml index 462bebc95..7806bab88 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/pom.xml @@ -31,11 +31,6 @@ - - org.deeplearning4j - deeplearning4j-util - ${project.version} - org.nd4j nd4j-api diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/MovingWindowDataSetFetcher.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/MovingWindowDataSetFetcher.java deleted file mode 100644 index e8bee9092..000000000 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/MovingWindowDataSetFetcher.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.datasets.iterator.impl; - -import org.deeplearning4j.util.MovingWindowMatrix; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher; -import org.nd4j.common.util.ArrayUtil; - -import java.util.ArrayList; -import java.util.List; - -/** - * - * Moving window data fetcher. Handles rotation of matrices in all directions - * to generate more examples. - * - * - * @author Adam Gibson - */ -public class MovingWindowDataSetFetcher extends BaseDataFetcher { - - private DataSet data; - private int windowRows = 28, windowColumns = 28; - private int cursor = 0; - - public MovingWindowDataSetFetcher(DataSet data, int windowRows, int windowColumns) { - this.data = data; - this.windowRows = windowRows; - this.windowColumns = windowColumns; - List list = data.asList(); - List flipped = new ArrayList<>(); - for (int i = 0; i < list.size(); i++) { - INDArray label = list.get(i).getLabels(); - List windows = - new MovingWindowMatrix(list.get(i).getFeatures(), windowRows, windowColumns, true) - .windows(true); - for (int j = 0; j < windows.size(); j++) { - flipped.add(new DataSet(windows.get(j), label)); - } - flipped.add(list.get(i)); - } - - this.data = DataSet.merge(flipped); - - } - - /** - * Fetches the next dataset. You need to call this - * to get a new dataset, otherwise {@link #next()} - * just returns the last data applyTransformToDestination fetch - * - * @param numExamples the number of examples to fetch - */ - @Override - public void fetch(int numExamples) { - initializeCurrFromList(data.get(ArrayUtil.range(cursor, cursor + numExamples)).asList()); - - } -} diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java index e3c603287..d47309d1d 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/KerasLoss.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; @@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo public class KerasLoss extends KerasLayer { private final String KERAS_CLASS_NAME_LOSS = "Loss"; - private LossFunctions.LossFunction loss; + private ILossFunction loss; /** @@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer { if (enforceTrainingConfig) throw e; log.warn("Unsupported Keras loss function. 
Replacing with MSE."); - loss = LossFunctions.LossFunction.SQUARED_LOSS; + loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction(); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java index 35cf34170..b9e0ddfce 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLossUtils.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.util.HashMap; +import java.util.Map; + + /** * Utility functionality for keras loss functions * @@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; */ @Slf4j public class KerasLossUtils { + static final Map customLoss = new HashMap<>(); + + /** + * Register a custom loss function + * + * @param lossName name of the lambda layer in the serialized Keras model + * @param lossFunction SameDiffLambdaLayer instance to map to Keras Lambda layer + */ + public static void registerCustomLoss(String lossName, ILossFunction lossFunction) { + customLoss.put(lossName, lossFunction); + } + + /** + * Clear all lambda layers + * + */ + public static void clearCustomLoss() { + customLoss.clear(); + } + /** * Map Keras to DL4J loss functions. 
* * @param kerasLoss String containing Keras loss function name * @return String containing DL4J loss function */ - public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf) + public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf) throws UnsupportedKerasConfigurationException { LossFunctions.LossFunction dl4jLoss; if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) || @@ -67,8 +93,13 @@ public class KerasLossUtils { } else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) { dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY; } else { - throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss); + ILossFunction lossClass = customLoss.get(kerasLoss); + if(lossClass != null){ + return lossClass; + }else{ + throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss); + } } - return dl4jLoss; + return dl4jLoss.getILossFunction(); } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java new file mode 100644 index 000000000..23c46835e --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java @@ -0,0 +1,78 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.deeplearning4j.nn.modelimport.keras.e2e; + +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; +import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.resources.Resources; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.lossfunctions.SameDiffLoss; + +import java.io.File; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; + + +/** + * Test importing Keras models with custom loss. 
+ * + * @author Paul Dubs + */ +public class KerasCustomLossTest extends BaseDL4JTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public class LogCosh extends SameDiffLoss { + @Override + public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { + return sd.math.log(sd.math.cosh(labels.sub(layerInput))); + } + } + + @Test + public void testSequentialLambdaLayerImport() throws Exception { + KerasLossUtils.registerCustomLoss("logcosh", new LogCosh()); + + String modelPath = "modelimport/keras/examples/custom_loss.h5"; + + try(InputStream is = Resources.asStream(modelPath)) { + File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5"); + Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING); + MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath()) + .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork(); + + System.out.println(model.summary()); + INDArray input = Nd4j.create(new int[]{10, 3}); + + model.output(input); + } finally { + KerasLossUtils.clearCustomLoss(); + } + } + + +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java index 7c9134e2c..969dbaeb9 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/glove/AbstractCoOccurrences.java @@ -29,7 +29,7 @@ import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator; import org.deeplearning4j.common.util.DL4JFileUtils; -import org.deeplearning4j.core.util.ThreadUtils; +import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.slf4j.Logger; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java index 1f1dce5f7..c007d4b96 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java @@ -47,7 +47,7 @@ import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter; import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator; import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; -import org.deeplearning4j.core.util.ThreadUtils; +import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; diff --git 
a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index 3f2d5f216..d31cc51b0 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -47,7 +47,7 @@ import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; -import org.deeplearning4j.core.util.ThreadUtils; +import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; import org.nd4j.linalg.api.memory.enums.LearningPolicy; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java index 10f2a4811..fca1288d0 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java @@ -27,7 +27,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.deeplearning4j.models.word2vec.Huffman; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; import org.deeplearning4j.text.invertedindex.InvertedIndex; -import org.deeplearning4j.core.util.ThreadUtils; +import org.nd4j.common.util.ThreadUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.threadly.concurrent.PriorityScheduler; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java index 8490a1f99..cb1e860c1 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java @@ -18,7 +18,7 @@ package org.deeplearning4j.text.sentenceiterator; import lombok.NonNull; -import org.deeplearning4j.core.util.ThreadUtils; +import org.nd4j.common.util.ThreadUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index 8daa947df..b2f64c894 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -63,6 +63,9 @@ public class Deconvolution2D extends ConvolutionLayer { protected Deconvolution2D(BaseConvBuilder builder) { super(builder); initializeConstraints(builder); + if(builder instanceof Builder){ + this.cnn2dDataFormat = ((Builder) builder).format; + } } public boolean hasBias() { @@ -136,7 +139,7 @@ public class Deconvolution2D extends ConvolutionLayer { private CNN2DFormat format = CNN2DFormat.NCHW; - public Builder format(CNN2DFormat format){ + public Builder dataFormat(CNN2DFormat format){ this.format = format; return this; } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index b8b0c13a9..81804a31f 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -310,11 +310,21 @@ public class ConvolutionLayer extends BaseLayer
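Usage sketch (illustrative only, not part of the patch): the snippet below shows how the APIs introduced in this diff would typically be exercised - the channels-last (NHWC) options added to the image loaders and ImageRecordReader, plus custom Keras loss registration via KerasLossUtils. The image directory and file name are hypothetical placeholders and the Keras model import call is elided; output shapes follow the Javadoc added to BaseImageLoader and ImageRecordReader.

import java.io.File;

import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.SameDiffLoss;

public class NchwNhwcUsageSketch {

    // Same LogCosh loss as defined in KerasCustomLossTest above
    public static class LogCosh extends SameDiffLoss {
        @Override
        public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
            return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
        }
    }

    public static void main(String[] args) throws Exception {
        // Single images: the existing methods still return NCHW/channels_first;
        // the new boolean argument selects NHWC/channels_last instead.
        NativeImageLoader loader = new NativeImageLoader(32, 32, 3);
        INDArray nchw = loader.asMatrix(new File("/path/to/images/img0.jpg"));        // [1, 3, 32, 32]
        INDArray nhwc = loader.asMatrix(new File("/path/to/images/img0.jpg"), false); // [1, 32, 32, 3]

        // Record reader: nchw_channels_first = false makes next()/next(int) emit channels-last arrays.
        ImageRecordReader rr = new ImageRecordReader(32, 32, 3, false);
        rr.initialize(new FileSplit(new File("/path/to/images"))); // feature batches: [n, 32, 32, 3]

        // Custom Keras loss: register an ILossFunction under the name used in the saved model,
        // run the Keras model import as usual, then clear the registry.
        KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
        // ... KerasSequentialModel / KerasModelImport call goes here ...
        KerasLossUtils.clearCustomLoss();
    }
}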