Merge pull request #8923 from KonduitAI/master

Development updates
Alex Black 2020-05-06 16:53:46 +10:00 committed by GitHub
commit 37d880a23c
36 changed files with 678 additions and 166 deletions

View File

@@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
this.arrays = arrays;
this.size = arrays.get(0).size(0);
//Check that dimension 0 matches:
if(arrays.size() > 1){

View File

@@ -16,6 +16,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.image.data.Image;
import org.datavec.image.transform.ImageTransform;
@@ -35,10 +36,9 @@ import java.util.Random;
/**
* Created by nyghtowl on 12/17/15.
*/
@Slf4j
public abstract class BaseImageLoader implements Serializable {
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
public enum MultiPageMode {
MINIBATCH, FIRST //, CHANNELS,
}
@@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
public abstract INDArray asMatrix(File f) throws IOException;
/**
* Load an image from a file to an INDArray
* @param f File to load the image from
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
* in NHWC/channels_last [1, height, width, channels] format
* @return Image file as an INDArray
*/
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
/**
* Load an image file from an input stream to an INDArray
* @param inputStream Input stream to load the image from
* @param nchw If true: return image in NCHW/channels_first [1, channels, height, width] format; if false, return
* in NHWC/channels_last [1, height, width, channels] format
* @return Image file stream as an INDArray
*/
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
public abstract Image asImageMatrix(File f) throws IOException;
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
public static void downloadAndUntar(Map urlMap, File fullDir) {
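The new boolean overloads make the array layout explicit at each call site. A minimal usage sketch, assuming a concrete subclass such as NativeImageLoader (shown later in this commit); the file path and 32x32x3 dimensions are illustrative, not part of the commit:

import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.File;

public class LayoutExample {
    public static void main(String[] args) throws Exception {
        NativeImageLoader loader = new NativeImageLoader(32, 32, 3);
        File f = new File("some_image.jpg"); // placeholder path

        INDArray nchw = loader.asMatrix(f);        // default layout: [1, 3, 32, 32]
        INDArray nhwc = loader.asMatrix(f, false); // channels last:  [1, 32, 32, 3]

        // Same data, different axis order:
        System.out.println(nchw.equals(nhwc.permute(0, 3, 1, 2)));
    }
}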

View File

@@ -16,6 +16,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.bytedeco.javacv.OpenCVFrameConverter;
@@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
*/
@Slf4j
public class CifarLoader extends NativeImageLoader implements Serializable {
public static final int NUM_TRAIN_IMAGES = 50000;
public static final int NUM_TEST_IMAGES = 10000;

View File

@@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
* @throws IOException
*/
public INDArray asMatrix(File f) throws IOException {
return NDArrayUtil.toNDArray(fromFile(f));
return asMatrix(f, true);
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
return asMatrix(is, nchw);
}
}
/**
@@ -259,34 +266,68 @@
* @return the input stream to convert
*/
public INDArray asMatrix(InputStream inputStream) throws IOException {
if (channels == 3)
return toBgr(inputStream);
try {
BufferedImage image = ImageIO.read(inputStream);
return asMatrix(image);
} catch (IOException e) {
throw new IOException("Unable to load image", e);
return asMatrix(inputStream, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
INDArray ret;
if (channels == 3) {
ret = toBgr(inputStream);
} else {
try {
BufferedImage image = ImageIO.read(inputStream);
ret = asMatrix(image);
} catch (IOException e) {
throw new IOException("Unable to load image", e);
}
}
if(ret.rank() == 3){
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
}
if(!nchw)
ret = ret.permute(0,2,3,1); //NCHW to NHWC
return ret;
}
@Override
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
return asImageMatrix(f, true);
}
@Override
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asImageMatrix(bis);
return asImageMatrix(bis, nchw);
}
}
@Override
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
if (channels == 3)
return toBgrImage(inputStream);
try {
BufferedImage image = ImageIO.read(inputStream);
INDArray asMatrix = asMatrix(image);
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
} catch (IOException e) {
throw new IOException("Unable to load image", e);
return asImageMatrix(inputStream, true);
}
@Override
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
org.datavec.image.data.Image ret;
if (channels == 3) {
ret = toBgrImage(inputStream);
} else {
try {
BufferedImage image = ImageIO.read(inputStream);
INDArray asMatrix = asMatrix(image);
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
} catch (IOException e) {
throw new IOException("Unable to load image", e);
}
}
if(ret.getImage().rank() == 3){
INDArray a = ret.getImage();
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
}
if(!nchw)
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
return ret;
}
/**

View File

@@ -17,6 +17,7 @@
package org.datavec.image.loader;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator;
@@ -48,6 +49,7 @@ import java.util.Random;
* most images are in color, although a few are grayscale
*
*/
@Slf4j
public class LFWLoader extends BaseImageLoader implements Serializable {
public final static int NUM_IMAGES = 13233;
@@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
throw new UnsupportedOperationException();
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public INDArray asMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Image asImageMatrix(File f) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Image asImageMatrix(InputStream inputStream) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new UnsupportedOperationException();
}
}

View File

@@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public INDArray asMatrix(File f) throws IOException {
return asMatrix(f, true);
}
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asMatrix(bis);
return asMatrix(bis, nchw);
}
}
@Override
public INDArray asMatrix(InputStream is) throws IOException {
Mat mat = streamToMat(is);
return asMatrix(is, true);
}
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
@@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
a = asMatrix(image);
image.deallocate();
}
return a;
if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
}
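Note that the NCHW-to-NHWC conversion above uses permute rather than reshape: permute reorders the axes (via the strides) so every pixel keeps its meaning, whereas reshaping the same buffer would scramble channel and spatial data. A small self-contained sketch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PermuteSketch {
    public static void main(String[] args) {
        INDArray nchw = Nd4j.arange(12).reshape(1, 3, 2, 2); // [N, C, H, W]
        INDArray nhwc = nchw.permute(0, 2, 3, 1);            // [N, H, W, C]
        // Element (n, c, h, w) of NCHW equals element (n, h, w, c) of NHWC:
        System.out.println(nchw.getDouble(0, 1, 0, 1) == nhwc.getDouble(0, 0, 1, 1)); // true
    }
}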
/**
@@ -331,19 +345,29 @@
}
public Image asImageMatrix(String filename) throws IOException {
return asImageMatrix(filename);
return asImageMatrix(new File(filename));
}
@Override
public Image asImageMatrix(File f) throws IOException {
return asImageMatrix(f, true);
}
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asImageMatrix(bis);
return asImageMatrix(bis, nchw);
}
}
@Override
public Image asImageMatrix(InputStream is) throws IOException {
Mat mat = streamToMat(is);
return asImageMatrix(is, true);
}
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
@@ -354,6 +378,8 @@
pixDestroy(pix);
}
INDArray a = asMatrix(image);
if(!nchw)
a = a.permute(0,2,3,1); //NCHW to NHWC
Image i = new Image(a, image.channels(), image.rows(), image.cols());
image.deallocate();

View File

@@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
protected int patternPosition = 0;
@Getter @Setter
protected boolean logLabelCountOnInit = true;
@Getter @Setter
protected boolean nchw_channels_first = true;
public final static String HEIGHT = NAME_SPACE + ".height";
public final static String WIDTH = NAME_SPACE + ".width";
@@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
}
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this.height = height;
this.width = width;
this.channels = channels;
@@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
this.labelMultiGenerator = labelMultiGenerator;
this.imageTransform = imageTransform;
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
this.nchw_channels_first = nchw_channels_first;
}
protected boolean containsFormat(String format) {
@@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
return next();
try {
invokeListeners(image);
INDArray row = imageLoader.asMatrix(image);
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE);
ret = RecordConverter.toRecord(row);
INDArray array = imageLoader.asMatrix(image);
if(!nchw_channels_first){
array = array.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
ret = RecordConverter.toRecord(array);
if (appendLabel || writeLabel){
if(labelMultiGenerator != null){
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
@@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
@Override
public List<List<Writable>> next(int num) {
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
@@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
throw new RuntimeException(e);
}
}
if(!nchw_channels_first){
features = features.permute(0,2,3,1); //NCHW to NHWC
}
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
@@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
if (imageLoader == null) {
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
}
INDArray row = imageLoader.asMatrix(dataInputStream);
List<Writable> ret = RecordConverter.toRecord(row);
INDArray array = imageLoader.asMatrix(dataInputStream);
if(!nchw_channels_first)
array = array.permute(0,2,3,1);
List<Writable> ret = RecordConverter.toRecord(array);
if (appendLabel)
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
return ret;

View File

@@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
public class ImageRecordReader extends BaseImageRecordReader {
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels.
* Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/
public ImageRecordReader() {
super();
}
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
/** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
*/
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator);
}
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
/** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
*/
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
super(height, width, channels, labelGenerator);
}
/** Loads images with given height, width, and channels, appending no labels. */
/** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */
public ImageRecordReader(long height, long width, long channels) {
super(height, width, channels, (PathLabelGenerator) null);
}
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
/** Loads images with given height, width, and channels, appending no labels - in specified format<br>
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
*/
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
super(height, width, channels, nchw_channels_first, null, null, null);
}
/** Loads images with given height, width, and channels, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, channels, height, width] */
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) {
super(height, width, channels, labelGenerator, imageTransform);
}
/** Loads images with given height, width, and channels, appending no labels. */
/** Loads images with given height, width, and channels, appending labels returned by the generator.<br>
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
*/
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
ImageTransform imageTransform) {
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
}
/** Loads images with given height, width, and channels, appending no labels.
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
super(height, width, channels, null, imageTransform);
}
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
/** Loads images with given height, width, and channels = 1, appending labels returned by the generator.
* Output format is NCHW (channels first) - [numExamples, 1, height, width]*/
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
super(height, width, 1, labelGenerator);
}
/** Loads images with given height, width, and channels = 1, appending no labels. */
/** Loads images with given height, width, and channels = 1, appending no labels.
* Output format is NCHW (channels first) - [numExamples, 1, height, width]*/
public ImageRecordReader(long height, long width) {
super(height, width, 1, null, null);
}
}
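A usage sketch for the new constructor flag (the image directory is a placeholder):

import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import java.io.File;

public class NhwcReaderSketch {
    public static void main(String[] args) throws Exception {
        // false => NHWC/channels-last records of shape [1, 32, 32, 3]
        ImageRecordReader reader = new ImageRecordReader(32, 32, 3, false);
        reader.initialize(new FileSplit(new File("/path/to/images")));
        while (reader.hasNext()) {
            System.out.println(reader.next());
        }
    }
}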

View File

@@ -16,10 +16,16 @@
package org.datavec.image.loader;
import org.datavec.image.data.Image;
import org.junit.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Random;
import static org.junit.Assert.assertEquals;
@@ -208,4 +214,57 @@ public class TestImageLoader {
private BufferedImage makeRandomBufferedImage(boolean alpha) {
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
}
@Test
public void testNCHW_NHWC() throws Exception {
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
ImageLoader il = new ImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f);
Image i_nchw2 = il.asImageMatrix(f, true);
Image i_nhwc = il.asImageMatrix(f, false);
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
}
}

View File

@@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.*;
import java.lang.reflect.Field;
import java.util.Random;
@@ -604,4 +603,56 @@ public class TestNativeImageLoader {
}
}
@Test
public void testNCHW_NHWC() throws Exception {
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f);
Image i_nchw2 = il.asImageMatrix(f, true);
Image i_nhwc = il.asImageMatrix(f, false);
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
}
}

View File

@@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.io.*;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
@@ -467,5 +467,87 @@ public class TestImageRecordReader {
return count;
}
}
@Test
public void testNCHW_NHWC() throws Exception {
//Idea: NCHW and NHWC readers should return the same image data, just in different layouts
File f0 = testDir.newFolder();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs0 = new FileSplit(f0, new Random(12345));
FileSplit fs1 = new FileSplit(f0, new Random(12345));
assertEquals(6, fs0.locations().length);
assertEquals(6, fs1.locations().length);
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
nchw.initialize(fs0);
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
nhwc.initialize(fs1);
while(nchw.hasNext()){
assertTrue(nhwc.hasNext());
List<Writable> l_nchw = nchw.next();
List<Writable> l_nhwc = nhwc.next();
INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get();
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get();
assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
assertEquals(a_nchw, permuted);
}
//Test batch:
nchw.reset();
nhwc.reset();
int batchCount = 0;
while(nchw.hasNext()){
assertTrue(nhwc.hasNext());
batchCount++;
List<List<Writable>> l_nchw = nchw.next(3);
List<List<Writable>> l_nhwc = nhwc.next(3);
assertEquals(3, l_nchw.size());
assertEquals(3, l_nhwc.size());
NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw;
NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc;
INDArray a_nchw = b_nchw.getArrays().get(0);
INDArray a_nhwc = b_nhwc.getArrays().get(0);
assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
assertEquals(a_nchw, permuted);
}
assertEquals(2, batchCount);
//Test record(URI, DataInputStream)
URI u = fs0.locations()[0];
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List<Writable> l = nchw.record(u, dis);
INDArray arr = ((NDArrayWritable)l.get(0)).get();
assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape());
}
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
List<Writable> l = nhwc.record(u, dis);
INDArray arr = ((NDArrayWritable)l.get(0)).get();
assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape());
}
}
}

View File

@@ -8,12 +8,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.nio.file.Files;
import java.util.concurrent.CountDownLatch;
@Ignore

View File

@@ -18,11 +18,9 @@ package org.deeplearning4j.nn.layers.convolution;
import lombok.*;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
@@ -35,6 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -49,6 +48,7 @@ import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(Parameterized.class)
public class ConvDataFormatTests extends BaseDL4JTest {
@@ -971,4 +971,58 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return null;
}
}
@Test
public void testWrongFormatIn(){
for(CNN2DFormat df : CNN2DFormat.values()){
for(int i=0; i<4; i++ ){
NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
.list();
switch (i){
case 0:
b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
break;
case 1:
b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
break;
case 2:
b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
break;
case 3:
b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
break;
}
MultiLayerNetwork net = new MultiLayerNetwork(b.build());
net.init();
INDArray in;
INDArray wrongFormatIn;
if(df == CNN2DFormat.NCHW){
in = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
} else {
in = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
}
net.output(in);
try {
net.output(wrongFormatIn);
} catch (DL4JInvalidInputException e){
// e.printStackTrace();
String msg = e.getMessage();
assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG));
}
}
}
}
}

View File

@@ -21,9 +21,9 @@ import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
import org.deeplearning4j.core.util.ThreadUtils;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;

View File

@@ -31,11 +31,6 @@
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-util</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>

View File

@@ -1,75 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datasets.iterator.impl;
import org.deeplearning4j.util.MovingWindowMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
import org.nd4j.common.util.ArrayUtil;
import java.util.ArrayList;
import java.util.List;
/**
*
* Moving window data fetcher. Handles rotation of matrices in all directions
* to generate more examples.
*
*
* @author Adam Gibson
*/
public class MovingWindowDataSetFetcher extends BaseDataFetcher {
private DataSet data;
private int windowRows = 28, windowColumns = 28;
private int cursor = 0;
public MovingWindowDataSetFetcher(DataSet data, int windowRows, int windowColumns) {
this.data = data;
this.windowRows = windowRows;
this.windowColumns = windowColumns;
List<DataSet> list = data.asList();
List<DataSet> flipped = new ArrayList<>();
for (int i = 0; i < list.size(); i++) {
INDArray label = list.get(i).getLabels();
List<INDArray> windows =
new MovingWindowMatrix(list.get(i).getFeatures(), windowRows, windowColumns, true)
.windows(true);
for (int j = 0; j < windows.size(); j++) {
flipped.add(new DataSet(windows.get(j), label));
}
flipped.add(list.get(i));
}
this.data = DataSet.merge(flipped);
}
/**
* Fetches the next dataset. You need to call this
* to get a new dataset, otherwise {@link #next()}
* just returns the last dataset fetched
*
* @param numExamples the number of examples to fetch
*/
@Override
public void fetch(int numExamples) {
initializeCurrFromList(data.get(ArrayUtil.range(cursor, cursor + numExamples)).asList());
}
}

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
@@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo
public class KerasLoss extends KerasLayer {
private final String KERAS_CLASS_NAME_LOSS = "Loss";
private LossFunctions.LossFunction loss;
private ILossFunction loss;
/**
@@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer {
if (enforceTrainingConfig)
throw e;
log.warn("Unsupported Keras loss function. Replacing with MSE.");
loss = LossFunctions.LossFunction.SQUARED_LOSS;
loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
}
}

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.HashMap;
import java.util.Map;
/**
* Utility functionality for Keras loss functions
*
@@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
*/
@Slf4j
public class KerasLossUtils {
static final Map<String, ILossFunction> customLoss = new HashMap<>();
/**
* Register a custom loss function
*
* @param lossName name of the loss function in the serialized Keras model
* @param lossFunction ILossFunction instance to map to that loss name
*/
public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
customLoss.put(lossName, lossFunction);
}
/**
* Clear all registered custom loss functions
*/
public static void clearCustomLoss() {
customLoss.clear();
}
/**
* Map Keras to DL4J loss functions.
*
* @param kerasLoss String containing Keras loss function name
* @return String containing DL4J loss function
*/
public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
throws UnsupportedKerasConfigurationException {
LossFunctions.LossFunction dl4jLoss;
if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||
@@ -67,8 +93,13 @@ public class KerasLossUtils {
} else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
} else {
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
ILossFunction lossClass = customLoss.get(kerasLoss);
if(lossClass != null){
return lossClass;
}else{
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
}
}
return dl4jLoss;
return dl4jLoss.getILossFunction();
}
}
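The intended usage pattern for the new registry (exercised end-to-end by the KerasCustomLossTest added later in this commit) is: register the loss under the name it had when the Keras model was compiled, import the model, then clear the registry. A sketch, with the model path as a placeholder and LogCosh standing in for any user-defined ILossFunction:

KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
try {
    MultiLayerNetwork model = new KerasSequentialModel().modelBuilder()
            .modelHdf5Filename("custom_loss.h5") // placeholder path
            .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
} finally {
    KerasLossUtils.clearCustomLoss();
}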

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.nn.modelimport.keras.e2e;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
import java.io.File;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
/**
* Test importing Keras models with custom loss.
*
* @author Paul Dubs
*/
public class KerasCustomLossTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
public class LogCosh extends SameDiffLoss {
@Override
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
}
}
@Test
public void testSequentialCustomLossImport() throws Exception {
KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
String modelPath = "modelimport/keras/examples/custom_loss.h5";
try(InputStream is = Resources.asStream(modelPath)) {
File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
.enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
System.out.println(model.summary());
INDArray input = Nd4j.create(new int[]{10, 3});
model.output(input);
} finally {
KerasLossUtils.clearCustomLoss();
}
}
}

View File

@@ -29,7 +29,7 @@ import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
import org.deeplearning4j.common.util.DL4JFileUtils;
import org.deeplearning4j.core.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;

View File

@@ -47,7 +47,7 @@ import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.deeplearning4j.core.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

View File

@@ -47,7 +47,7 @@ import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.core.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@@ -27,7 +27,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.deeplearning4j.core.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;

View File

@@ -18,7 +18,7 @@ package org.deeplearning4j.text.sentenceiterator;
import lombok.NonNull;
import org.deeplearning4j.core.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

View File

@@ -63,6 +63,9 @@ public class Deconvolution2D extends ConvolutionLayer {
protected Deconvolution2D(BaseConvBuilder<?> builder) {
super(builder);
initializeConstraints(builder);
if(builder instanceof Builder){
this.cnn2dDataFormat = ((Builder) builder).format;
}
}
public boolean hasBias() {
@@ -136,7 +139,7 @@
private CNN2DFormat format = CNN2DFormat.NCHW;
public Builder format(CNN2DFormat format){
public Builder dataFormat(CNN2DFormat format){
this.format = format;
return this;
}
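With the rename, Deconvolution2D is configured like the other convolution builders. A sketch (sizes are illustrative):

Deconvolution2D deconv = new Deconvolution2D.Builder()
        .kernelSize(2, 2)
        .nIn(3).nOut(3)
        .dataFormat(CNN2DFormat.NHWC) // channels-last inputs/activations
        .build();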

View File

@@ -310,11 +310,21 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
String layerName = conf.getLayer().getLayerName();
if (layerName == null)
layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId());
+ layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
}
}

View File

@@ -190,12 +190,21 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName();
if (layerName == null)
layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
String s = "Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(cDim) + ", "
+ " (data format = " + format + ", data input channels = " + input.size(cDim) + ", "
+ (nchw ? "[minibatch,inputDepth,height,width]" : "[minibatch,height,width,inputDepth]") + "="
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId());
+ layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
}
int kH = (int) weights.size(2);
int kW = (int) weights.size(3);

View File

@@ -183,13 +183,21 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName();
if (layerName == null)
layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " +
String s = "Cannot do forward pass in DepthwiseConvolution2D layer " +
"(layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(1) + ", "
+ " (data format = " + format + ", data input channels = " + input.size(1) + ", "
+ (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=")
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId());
+ layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
}
int kH = (int) depthWiseWeights.size(0);
int kW = (int) depthWiseWeights.size(1);

View File

@@ -211,11 +211,20 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
String layerName = conf.getLayer().getLayerName();
if (layerName == null)
layerName = "(not named)";
throw new DL4JInvalidInputException("Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
+ " (data format = " + format + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId());
+ layerId();
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
if(input.size(dimIfWrongFormat) == inDepth){
//User might have passed NCHW data to a NHWC net, or vice versa?
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
}
throw new DL4JInvalidInputException(s);
}
int kH = (int) depthWiseWeights.size(2);
int kW = (int) depthWiseWeights.size(3);

View File

@@ -20,7 +20,7 @@ import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;

View File

@@ -27,8 +27,8 @@ import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostPro
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.util.ThreadUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ThreadUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;

View File

@@ -18,6 +18,7 @@ package org.deeplearning4j.optimize.solvers.accumulation;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.util.ThreadUtils;
import java.util.Collection;
import java.util.Iterator;
@@ -28,8 +29,6 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.deeplearning4j.util.ThreadUtils;
/**
* This BlockingQueue implementation is suited only for symmetric gradients updates, and should NOT be used anywhere else.
*

View File

@@ -48,6 +48,13 @@ import java.util.Arrays;
*/
public class ConvolutionUtils {
public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first)" +
" or NHWC (channels last) format for input images and activations.\n" +
"Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using" +
" .setInputType(InputType.convolutional(height, width, depth, CNN2DForman.NCHW/NHWC)).\n" +
"ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";
private static final int[] ONES = new int[]{1, 1};
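As a sketch, the whole-network form the message recommends looks like this (layer sizes are illustrative, not part of this commit):

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list()
        .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(8).build())
        .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build())
        // Height, width, depth, and the format every layer should expect:
        .setInputType(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
        .build();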

View File

@@ -1,5 +1,6 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -16,6 +17,7 @@
package org.nd4j.linalg.dataset;
import lombok.NonNull;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@@ -43,7 +45,7 @@ public class ExistingMiniBatchDataSetIterator implements DataSetIterator {
* Create with the given root directory, using the default filename pattern {@link #DEFAULT_PATTERN}
* @param rootDir the root directory to use
*/
public ExistingMiniBatchDataSetIterator(File rootDir) {
public ExistingMiniBatchDataSetIterator(@NonNull File rootDir) {
this(rootDir, DEFAULT_PATTERN);
}
@@ -53,7 +55,7 @@ public class ExistingMiniBatchDataSetIterator implements DataSetIterator {
* @param pattern The filename pattern to use. Used with {@code String.format(pattern,idx)}, where idx is an
* integer, starting at 0.
*/
public ExistingMiniBatchDataSetIterator(File rootDir, String pattern) {
public ExistingMiniBatchDataSetIterator(@NonNull File rootDir, String pattern) {
this.rootDir = rootDir;
totalBatches = rootDir.list().length;
this.pattern = pattern;

View File

@@ -38,7 +38,7 @@ import java.util.Map;
*/
public abstract class SameDiffLoss implements ILossFunction {
protected transient SameDiff sd;
protected transient SDVariable scoreVariable;
protected transient SDVariable scorePerExampleVariable;
protected SameDiffLoss() {
@@ -60,7 +60,8 @@ public abstract class SameDiffLoss implements ILossFunction {
sd = SameDiff.create();
SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
SDVariable labels = sd.placeHolder("labels", dataType, -1);
scoreVariable = this.defineLoss(sd, layerInput, labels);
scorePerExampleVariable = this.defineLoss(sd, layerInput, labels);
scorePerExampleVariable.markAsLoss();
sd.createGradFunction("layerInput");
}
@@ -112,7 +113,7 @@ public abstract class SameDiffLoss implements ILossFunction {
m.put("labels", labels);
m.put("layerInput", output);
INDArray scoreArr = sd.outputSingle(m,scoreVariable.name());
INDArray scoreArr = sd.outputSingle(m, scorePerExampleVariable.name());
if (mask != null) {
LossUtil.applyMask(scoreArr, mask);

View File

@@ -0,0 +1,29 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.common.util;
public class ThreadUtils {
private ThreadUtils(){ }
public static void uncheckedSleep(long sleepTimeMs){
try{
Thread.sleep(sleepTimeMs);
} catch (InterruptedException e){ }
}
}