commit
37d880a23c
|
@ -52,6 +52,7 @@ public class NDArrayRecordBatch extends AbstractWritableRecordBatch {
|
|||
public NDArrayRecordBatch(@NonNull List<INDArray> arrays){
|
||||
Preconditions.checkArgument(arrays.size() > 0, "Input list must not be empty");
|
||||
this.arrays = arrays;
|
||||
this.size = arrays.get(0).size(0);
|
||||
|
||||
//Check that dimension 0 matches:
|
||||
if(arrays.size() > 1){
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.transform.ImageTransform;
|
||||
|
@ -35,10 +36,9 @@ import java.util.Random;
|
|||
/**
|
||||
* Created by nyghtowl on 12/17/15.
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseImageLoader implements Serializable {
|
||||
|
||||
protected static final Logger log = LoggerFactory.getLogger(BaseImageLoader.class);
|
||||
|
||||
public enum MultiPageMode {
|
||||
MINIBATCH, FIRST //, CHANNELS,
|
||||
}
|
||||
|
@ -62,13 +62,37 @@ public abstract class BaseImageLoader implements Serializable {
|
|||
|
||||
public abstract INDArray asRowVector(InputStream inputStream) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */
|
||||
public abstract INDArray asMatrix(File f) throws IOException;
|
||||
|
||||
/**
|
||||
* Load an image from a file to an INDArray
|
||||
* @param f File to load the image from
|
||||
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||
* in NHWC/channels_last [1, height, width, channels] format
|
||||
* @return Image file as as INDArray
|
||||
*/
|
||||
public abstract INDArray asMatrix(File f, boolean nchw) throws IOException;
|
||||
|
||||
public abstract INDArray asMatrix(InputStream inputStream) throws IOException;
|
||||
/**
|
||||
* Load an image file from an input stream to an INDArray
|
||||
* @param inputStream Input stream to load the image from
|
||||
* @param nchw If true: return image in NCHW/channels_first [1, channels, height width] format; if false, return
|
||||
* in NHWC/channels_last [1, height, width, channels] format
|
||||
* @return Image file stream as as INDArray
|
||||
*/
|
||||
public abstract INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(File)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(File f) throws IOException;
|
||||
/** As per {@link #asMatrix(File, boolean)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(File f, boolean nchw) throws IOException;
|
||||
|
||||
/** As per {@link #asMatrix(InputStream)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(InputStream inputStream) throws IOException;
|
||||
/** As per {@link #asMatrix(InputStream, boolean)} but as an {@link Image}*/
|
||||
public abstract Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException;
|
||||
|
||||
|
||||
public static void downloadAndUntar(Map urlMap, File fullDir) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
|
@ -47,6 +48,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
|||
* There is a special preProcessor used to normalize the dataset based on Sergey Zagoruyko example
|
||||
* <a href="https://github.com/szagoruyko/cifar.torch">https://github.com/szagoruyko/cifar.torch</a>
|
||||
*/
|
||||
@Slf4j
|
||||
public class CifarLoader extends NativeImageLoader implements Serializable {
|
||||
public static final int NUM_TRAIN_IMAGES = 50000;
|
||||
public static final int NUM_TEST_IMAGES = 10000;
|
||||
|
|
|
@ -249,7 +249,14 @@ public class ImageLoader extends BaseImageLoader {
|
|||
* @throws IOException
|
||||
*/
|
||||
public INDArray asMatrix(File f) throws IOException {
|
||||
return NDArrayUtil.toNDArray(fromFile(f));
|
||||
return asMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
return asMatrix(is, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -259,34 +266,68 @@ public class ImageLoader extends BaseImageLoader {
|
|||
* @return the input stream to convert
|
||||
*/
|
||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||
if (channels == 3)
|
||||
return toBgr(inputStream);
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
return asMatrix(image);
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
return asMatrix(inputStream, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
INDArray ret;
|
||||
if (channels == 3) {
|
||||
ret = toBgr(inputStream);
|
||||
} else {
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
ret = asMatrix(image);
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
}
|
||||
}
|
||||
if(ret.rank() == 3){
|
||||
ret = ret.reshape(1, ret.size(0), ret.size(1), ret.size(2));
|
||||
}
|
||||
if(!nchw)
|
||||
ret = ret.permute(0,2,3,1); //NCHW to NHWC
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(File f) throws IOException {
|
||||
return asImageMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asImageMatrix(bis);
|
||||
return asImageMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||
if (channels == 3)
|
||||
return toBgrImage(inputStream);
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
INDArray asMatrix = asMatrix(image);
|
||||
return new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
return asImageMatrix(inputStream, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.datavec.image.data.Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
org.datavec.image.data.Image ret;
|
||||
if (channels == 3) {
|
||||
ret = toBgrImage(inputStream);
|
||||
} else {
|
||||
try {
|
||||
BufferedImage image = ImageIO.read(inputStream);
|
||||
INDArray asMatrix = asMatrix(image);
|
||||
ret = new org.datavec.image.data.Image(asMatrix, image.getData().getNumBands(), image.getHeight(), image.getWidth());
|
||||
} catch (IOException e) {
|
||||
throw new IOException("Unable to load image", e);
|
||||
}
|
||||
}
|
||||
if(ret.getImage().rank() == 3){
|
||||
INDArray a = ret.getImage();
|
||||
ret.setImage(a.reshape(1, a.size(0), a.size(1), a.size(2)));
|
||||
}
|
||||
if(!nchw)
|
||||
ret.setImage(ret.getImage().permute(0,2,3,1)); //NCHW to NHWC
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.datavec.image.loader;
|
||||
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.datavec.api.io.filters.BalancedPathFilter;
|
||||
import org.datavec.api.io.labels.PathLabelGenerator;
|
||||
import org.datavec.api.io.labels.PatternPathLabelGenerator;
|
||||
|
@ -48,6 +49,7 @@ import java.util.Random;
|
|||
* most images are in color, although a few are grayscale
|
||||
*
|
||||
*/
|
||||
@Slf4j
|
||||
public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||
|
||||
public final static int NUM_IMAGES = 13233;
|
||||
|
@ -270,19 +272,39 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -248,17 +248,27 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
|
||||
@Override
|
||||
public INDArray asMatrix(File f) throws IOException {
|
||||
return asMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asMatrix(bis);
|
||||
return asMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream is) throws IOException {
|
||||
Mat mat = streamToMat(is);
|
||||
return asMatrix(is, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
Mat mat = streamToMat(inputStream);
|
||||
INDArray a;
|
||||
if (this.multiPageMode != null) {
|
||||
a = asMatrix(mat.data(), mat.cols());
|
||||
a = asMatrix(mat.data(), mat.cols());
|
||||
}else{
|
||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||
if (image == null || image.empty()) {
|
||||
|
@ -272,7 +282,11 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
a = asMatrix(image);
|
||||
image.deallocate();
|
||||
}
|
||||
return a;
|
||||
if(nchw) {
|
||||
return a;
|
||||
} else {
|
||||
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -331,19 +345,29 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
}
|
||||
|
||||
public Image asImageMatrix(String filename) throws IOException {
|
||||
return asImageMatrix(filename);
|
||||
return asImageMatrix(new File(filename));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f) throws IOException {
|
||||
return asImageMatrix(f, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
||||
return asImageMatrix(bis);
|
||||
return asImageMatrix(bis, nchw);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream is) throws IOException {
|
||||
Mat mat = streamToMat(is);
|
||||
return asImageMatrix(is, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||
Mat mat = streamToMat(inputStream);
|
||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||
if (image == null || image.empty()) {
|
||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||
|
@ -354,6 +378,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
|||
pixDestroy(pix);
|
||||
}
|
||||
INDArray a = asMatrix(image);
|
||||
if(!nchw)
|
||||
a = a.permute(0,2,3,1); //NCHW to NHWC
|
||||
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
||||
|
||||
image.deallocate();
|
||||
|
|
|
@ -77,6 +77,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
protected int patternPosition = 0;
|
||||
@Getter @Setter
|
||||
protected boolean logLabelCountOnInit = true;
|
||||
@Getter @Setter
|
||||
protected boolean nchw_channels_first = true;
|
||||
|
||||
public final static String HEIGHT = NAME_SPACE + ".height";
|
||||
public final static String WIDTH = NAME_SPACE + ".width";
|
||||
|
@ -101,6 +103,11 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
|
||||
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
|
||||
}
|
||||
|
||||
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
|
||||
this.height = height;
|
||||
this.width = width;
|
||||
this.channels = channels;
|
||||
|
@ -108,6 +115,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
this.labelMultiGenerator = labelMultiGenerator;
|
||||
this.imageTransform = imageTransform;
|
||||
this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
|
||||
this.nchw_channels_first = nchw_channels_first;
|
||||
}
|
||||
|
||||
protected boolean containsFormat(String format) {
|
||||
|
@ -237,9 +245,13 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
return next();
|
||||
try {
|
||||
invokeListeners(image);
|
||||
INDArray row = imageLoader.asMatrix(image);
|
||||
Nd4j.getAffinityManager().ensureLocation(row, AffinityManager.Location.DEVICE);
|
||||
ret = RecordConverter.toRecord(row);
|
||||
INDArray array = imageLoader.asMatrix(image);
|
||||
if(!nchw_channels_first){
|
||||
array = array.permute(0,2,3,1); //NCHW to NHWC
|
||||
}
|
||||
|
||||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
|
||||
ret = RecordConverter.toRecord(array);
|
||||
if (appendLabel || writeLabel){
|
||||
if(labelMultiGenerator != null){
|
||||
ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
|
||||
|
@ -286,7 +298,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
|
||||
@Override
|
||||
public List<List<Writable>> next(int num) {
|
||||
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got " + num);
|
||||
Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
|
||||
|
||||
if (imageLoader == null) {
|
||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||
|
@ -337,6 +349,9 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
if(!nchw_channels_first){
|
||||
features = features.permute(0,2,3,1); //NCHW to NHWC
|
||||
}
|
||||
Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);
|
||||
|
||||
|
||||
|
@ -483,8 +498,10 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
|||
if (imageLoader == null) {
|
||||
imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
|
||||
}
|
||||
INDArray row = imageLoader.asMatrix(dataInputStream);
|
||||
List<Writable> ret = RecordConverter.toRecord(row);
|
||||
INDArray array = imageLoader.asMatrix(dataInputStream);
|
||||
if(!nchw_channels_first)
|
||||
array = array.permute(0,2,3,1);
|
||||
List<Writable> ret = RecordConverter.toRecord(array);
|
||||
if (appendLabel)
|
||||
ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
|
||||
return ret;
|
||||
|
|
|
@ -34,47 +34,70 @@ import org.datavec.image.transform.ImageTransform;
|
|||
public class ImageRecordReader extends BaseImageRecordReader {
|
||||
|
||||
|
||||
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels. */
|
||||
/** Loads images with height = 28, width = 28, and channels = 1, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, 1, 28, 28]*/
|
||||
public ImageRecordReader() {
|
||||
super();
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
|
||||
super(height, width, channels, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
|
||||
super(height, width, channels, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels. */
|
||||
/** Loads images with given height, width, and channels, appending no labels - in NCHW (channels first) format */
|
||||
public ImageRecordReader(long height, long width, long channels) {
|
||||
super(height, width, channels, (PathLabelGenerator) null);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending no labels - in specified format<br>
|
||||
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first) {
|
||||
super(height, width, channels, nchw_channels_first, null, null, null);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width] */
|
||||
public ImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
|
||||
ImageTransform imageTransform) {
|
||||
super(height, width, channels, labelGenerator, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator.<br>
|
||||
* If {@code nchw_channels_first == true} output format is NCHW (channels first) - [numExamples, channels, height, width]<br>
|
||||
* If {@code nchw_channels_first == false} output format is NHWC (channels last) - [numExamples, height, width, channels]<br>
|
||||
*/
|
||||
public ImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
|
||||
ImageTransform imageTransform) {
|
||||
super(height, width, channels, nchw_channels_first, labelGenerator, null, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width, long channels, ImageTransform imageTransform) {
|
||||
super(height, width, channels, null, imageTransform);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator. */
|
||||
/** Loads images with given height, width, and channels, appending labels returned by the generator
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width, PathLabelGenerator labelGenerator) {
|
||||
super(height, width, 1, labelGenerator);
|
||||
}
|
||||
|
||||
/** Loads images with given height, width, and channels = 1, appending no labels. */
|
||||
/** Loads images with given height, width, and channels = 1, appending no labels.
|
||||
* Output format is NCHW (channels first) - [numExamples, channels, height, width]*/
|
||||
public ImageRecordReader(long height, long width) {
|
||||
super(height, width, 1, null, null);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -16,10 +16,16 @@
|
|||
|
||||
package org.datavec.image.loader;
|
||||
|
||||
import org.datavec.image.data.Image;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.BufferedInputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.util.Random;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
@ -208,4 +214,57 @@ public class TestImageLoader {
|
|||
private BufferedImage makeRandomBufferedImage(boolean alpha) {
|
||||
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testNCHW_NHWC() throws Exception {
|
||||
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||
|
||||
ImageLoader il = new ImageLoader(32, 32, 3);
|
||||
|
||||
//asMatrix(File, boolean)
|
||||
INDArray a_nchw = il.asMatrix(f);
|
||||
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||
INDArray a_nhwc = il.asMatrix(f, false);
|
||||
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw = il.asMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw2 = il.asMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nhwc = il.asMatrix(is, false);
|
||||
}
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asImageMatrix(File, boolean)
|
||||
Image i_nchw = il.asImageMatrix(f);
|
||||
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||
Image i_nhwc = il.asImageMatrix(f, false);
|
||||
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
|
||||
|
||||
//asImageMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw = il.asImageMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw2 = il.asImageMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nhwc = il.asImageMatrix(is, false);
|
||||
}
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,20 +24,19 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
|
|||
import org.bytedeco.javacv.Frame;
|
||||
import org.bytedeco.javacv.Java2DFrameConverter;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.InputStream;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.Random;
|
||||
|
||||
|
@ -604,4 +603,56 @@ public class TestNativeImageLoader {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNCHW_NHWC() throws Exception {
|
||||
File f = Resources.asFile("datavec-data-image/voc/2007/JPEGImages/000005.jpg");
|
||||
|
||||
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
||||
|
||||
//asMatrix(File, boolean)
|
||||
INDArray a_nchw = il.asMatrix(f);
|
||||
INDArray a_nchw2 = il.asMatrix(f, true);
|
||||
INDArray a_nhwc = il.asMatrix(f, false);
|
||||
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw = il.asMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nchw2 = il.asMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
a_nhwc = il.asMatrix(is, false);
|
||||
}
|
||||
assertEquals(a_nchw, a_nchw2);
|
||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
||||
|
||||
|
||||
//asImageMatrix(File, boolean)
|
||||
Image i_nchw = il.asImageMatrix(f);
|
||||
Image i_nchw2 = il.asImageMatrix(f, true);
|
||||
Image i_nhwc = il.asImageMatrix(f, false);
|
||||
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
|
||||
|
||||
//asImageMatrix(InputStream, boolean)
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw = il.asImageMatrix(is);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nchw2 = il.asImageMatrix(is, true);
|
||||
}
|
||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
||||
i_nhwc = il.asImageMatrix(is, false);
|
||||
}
|
||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -35,13 +35,13 @@ import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.*;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -467,5 +467,87 @@ public class TestImageRecordReader {
|
|||
return count;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testNCHW_NCHW() throws Exception {
|
||||
//Idea: labels order should be consistent regardless of input file order
|
||||
File f0 = testDir.newFolder();
|
||||
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
|
||||
|
||||
FileSplit fs0 = new FileSplit(f0, new Random(12345));
|
||||
FileSplit fs1 = new FileSplit(f0, new Random(12345));
|
||||
assertEquals(6, fs0.locations().length);
|
||||
assertEquals(6, fs1.locations().length);
|
||||
|
||||
ImageRecordReader nchw = new ImageRecordReader(32, 32, 3, true);
|
||||
nchw.initialize(fs0);
|
||||
|
||||
ImageRecordReader nhwc = new ImageRecordReader(32, 32, 3, false);
|
||||
nhwc.initialize(fs1);
|
||||
|
||||
while(nchw.hasNext()){
|
||||
assertTrue(nhwc.hasNext());
|
||||
|
||||
List<Writable> l_nchw = nchw.next();
|
||||
List<Writable> l_nhwc = nhwc.next();
|
||||
|
||||
INDArray a_nchw = ((NDArrayWritable)l_nchw.get(0)).get();
|
||||
INDArray a_nhwc = ((NDArrayWritable)l_nhwc.get(0)).get();
|
||||
|
||||
assertArrayEquals(new long[]{1, 3, 32, 32}, a_nchw.shape());
|
||||
assertArrayEquals(new long[]{1, 32, 32, 3}, a_nhwc.shape());
|
||||
|
||||
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||
assertEquals(a_nchw, permuted);
|
||||
}
|
||||
|
||||
|
||||
//Test batch:
|
||||
nchw.reset();
|
||||
nhwc.reset();
|
||||
|
||||
int batchCount = 0;
|
||||
while(nchw.hasNext()){
|
||||
assertTrue(nhwc.hasNext());
|
||||
batchCount++;
|
||||
|
||||
List<List<Writable>> l_nchw = nchw.next(3);
|
||||
List<List<Writable>> l_nhwc = nhwc.next(3);
|
||||
assertEquals(3, l_nchw.size());
|
||||
assertEquals(3, l_nhwc.size());
|
||||
|
||||
NDArrayRecordBatch b_nchw = (NDArrayRecordBatch)l_nchw;
|
||||
NDArrayRecordBatch b_nhwc = (NDArrayRecordBatch)l_nhwc;
|
||||
|
||||
INDArray a_nchw = b_nchw.getArrays().get(0);
|
||||
INDArray a_nhwc = b_nhwc.getArrays().get(0);
|
||||
|
||||
assertArrayEquals(new long[]{3, 3, 32, 32}, a_nchw.shape());
|
||||
assertArrayEquals(new long[]{3, 32, 32, 3}, a_nhwc.shape());
|
||||
|
||||
INDArray permuted = a_nhwc.permute(0,3,1,2); //NHWC to NCHW
|
||||
assertEquals(a_nchw, permuted);
|
||||
}
|
||||
assertEquals(2, batchCount);
|
||||
|
||||
|
||||
//Test record(URI, DataInputStream)
|
||||
|
||||
URI u = fs0.locations()[0];
|
||||
|
||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||
List<Writable> l = nchw.record(u, dis);
|
||||
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||
assertArrayEquals(new long[]{1, 3, 32, 32}, arr.shape());
|
||||
}
|
||||
|
||||
try(DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(new File(u))))) {
|
||||
List<Writable> l = nhwc.record(u, dis);
|
||||
INDArray arr = ((NDArrayWritable)l.get(0)).get();
|
||||
assertArrayEquals(new long[]{1, 32, 32, 3}, arr.shape());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,12 +8,14 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.learning.config.RmsProp;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.nio.file.Files;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Ignore
|
||||
|
|
|
@ -18,11 +18,9 @@ package org.deeplearning4j.nn.layers.convolution;
|
|||
import lombok.*;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.TestUtils;
|
||||
import org.deeplearning4j.exception.DL4JInvalidInputException;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.*;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
|
||||
|
@ -35,6 +33,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
|
|||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.workspace.ArrayType;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
import org.deeplearning4j.util.ConvolutionUtils;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
|
@ -49,6 +48,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
@RunWith(Parameterized.class)
|
||||
public class ConvDataFormatTests extends BaseDL4JTest {
|
||||
|
@ -971,4 +971,58 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testWrongFormatIn(){
|
||||
|
||||
for(CNN2DFormat df : CNN2DFormat.values()){
|
||||
|
||||
|
||||
for(int i=0; i<4; i++ ){
|
||||
|
||||
NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder()
|
||||
.list();
|
||||
switch (i){
|
||||
case 0:
|
||||
b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
||||
break;
|
||||
case 1:
|
||||
b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
||||
break;
|
||||
case 2:
|
||||
b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||
break;
|
||||
case 3:
|
||||
b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||
break;
|
||||
}
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(b.build());
|
||||
net.init();
|
||||
|
||||
INDArray in;
|
||||
INDArray wrongFormatIn;
|
||||
if(df == CNN2DFormat.NCHW){
|
||||
in = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
|
||||
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
|
||||
} else {
|
||||
in = Nd4j.create(DataType.FLOAT, 5, 12, 12, 3);
|
||||
wrongFormatIn = Nd4j.create(DataType.FLOAT, 5, 3, 12, 12);
|
||||
}
|
||||
|
||||
net.output(in);
|
||||
|
||||
try {
|
||||
net.output(wrongFormatIn);
|
||||
} catch (DL4JInvalidInputException e){
|
||||
// e.printStackTrace();
|
||||
String msg = e.getMessage();
|
||||
assertTrue(msg, msg.contains(ConvolutionUtils.NCHW_NHWC_ERROR_MSG));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,9 +21,9 @@ import lombok.val;
|
|||
import org.apache.commons.lang3.RandomUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
|
|
@ -31,11 +31,6 @@
|
|||
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>deeplearning4j-util</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
|
|
|
@ -1,75 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.datasets.iterator.impl;
|
||||
|
||||
import org.deeplearning4j.util.MovingWindowMatrix;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.fetcher.BaseDataFetcher;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
*
|
||||
* Moving window data fetcher. Handles rotation of matrices in all directions
|
||||
* to generate more examples.
|
||||
*
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class MovingWindowDataSetFetcher extends BaseDataFetcher {
|
||||
|
||||
private DataSet data;
|
||||
private int windowRows = 28, windowColumns = 28;
|
||||
private int cursor = 0;
|
||||
|
||||
public MovingWindowDataSetFetcher(DataSet data, int windowRows, int windowColumns) {
|
||||
this.data = data;
|
||||
this.windowRows = windowRows;
|
||||
this.windowColumns = windowColumns;
|
||||
List<DataSet> list = data.asList();
|
||||
List<DataSet> flipped = new ArrayList<>();
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
INDArray label = list.get(i).getLabels();
|
||||
List<INDArray> windows =
|
||||
new MovingWindowMatrix(list.get(i).getFeatures(), windowRows, windowColumns, true)
|
||||
.windows(true);
|
||||
for (int j = 0; j < windows.size(); j++) {
|
||||
flipped.add(new DataSet(windows.get(j), label));
|
||||
}
|
||||
flipped.add(list.get(i));
|
||||
}
|
||||
|
||||
this.data = DataSet.merge(flipped);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches the next dataset. You need to call this
|
||||
* to get a new dataset, otherwise {@link #next()}
|
||||
* just returns the last data applyTransformToDestination fetch
|
||||
*
|
||||
* @param numExamples the number of examples to fetch
|
||||
*/
|
||||
@Override
|
||||
public void fetch(int numExamples) {
|
||||
initializeCurrFromList(data.get(ArrayUtil.range(cursor, cursor + numExamples)).asList());
|
||||
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -28,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
|||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -45,7 +47,7 @@ import static org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils.mapLo
|
|||
public class KerasLoss extends KerasLayer {
|
||||
|
||||
private final String KERAS_CLASS_NAME_LOSS = "Loss";
|
||||
private LossFunctions.LossFunction loss;
|
||||
private ILossFunction loss;
|
||||
|
||||
|
||||
/**
|
||||
|
@ -86,7 +88,7 @@ public class KerasLoss extends KerasLayer {
|
|||
if (enforceTrainingConfig)
|
||||
throw e;
|
||||
log.warn("Unsupported Keras loss function. Replacing with MSE.");
|
||||
loss = LossFunctions.LossFunction.SQUARED_LOSS;
|
||||
loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -19,8 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.utils;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
/**
|
||||
* Utility functionality for keras loss functions
|
||||
*
|
||||
|
@ -28,13 +34,33 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|||
*/
|
||||
@Slf4j
|
||||
public class KerasLossUtils {
|
||||
static final Map<String, ILossFunction> customLoss = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Register a custom loss function
|
||||
*
|
||||
* @param lossName name of the lambda layer in the serialized Keras model
|
||||
* @param lossFunction SameDiffLambdaLayer instance to map to Keras Lambda layer
|
||||
*/
|
||||
public static void registerCustomLoss(String lossName, ILossFunction lossFunction) {
|
||||
customLoss.put(lossName, lossFunction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all lambda layers
|
||||
*
|
||||
*/
|
||||
public static void clearCustomLoss() {
|
||||
customLoss.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Keras to DL4J loss functions.
|
||||
*
|
||||
* @param kerasLoss String containing Keras loss function name
|
||||
* @return String containing DL4J loss function
|
||||
*/
|
||||
public static LossFunctions.LossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
|
||||
public static ILossFunction mapLossFunction(String kerasLoss, KerasLayerConfiguration conf)
|
||||
throws UnsupportedKerasConfigurationException {
|
||||
LossFunctions.LossFunction dl4jLoss;
|
||||
if (kerasLoss.equals(conf.getKERAS_LOSS_MEAN_SQUARED_ERROR()) ||
|
||||
|
@ -67,8 +93,13 @@ public class KerasLossUtils {
|
|||
} else if (kerasLoss.equals(conf.getKERAS_LOSS_COSINE_PROXIMITY())) {
|
||||
dl4jLoss = LossFunctions.LossFunction.COSINE_PROXIMITY;
|
||||
} else {
|
||||
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
|
||||
ILossFunction lossClass = customLoss.get(kerasLoss);
|
||||
if(lossClass != null){
|
||||
return lossClass;
|
||||
}else{
|
||||
throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + kerasLoss);
|
||||
}
|
||||
}
|
||||
return dl4jLoss;
|
||||
return dl4jLoss.getILossFunction();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.deeplearning4j.nn.modelimport.keras.e2e;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
|
||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.lossfunctions.SameDiffLoss;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.InputStream;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
|
||||
|
||||
/**
|
||||
* Test importing Keras models with custom loss.
|
||||
*
|
||||
* @author Paul Dubs
|
||||
*/
|
||||
public class KerasCustomLossTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
public class LogCosh extends SameDiffLoss {
|
||||
@Override
|
||||
public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
|
||||
return sd.math.log(sd.math.cosh(labels.sub(layerInput)));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequentialLambdaLayerImport() throws Exception {
|
||||
KerasLossUtils.registerCustomLoss("logcosh", new LogCosh());
|
||||
|
||||
String modelPath = "modelimport/keras/examples/custom_loss.h5";
|
||||
|
||||
try(InputStream is = Resources.asStream(modelPath)) {
|
||||
File modelFile = testDir.newFile("tempModel" + System.currentTimeMillis() + ".h5");
|
||||
Files.copy(is, modelFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
|
||||
MultiLayerNetwork model = new KerasSequentialModel().modelBuilder().modelHdf5Filename(modelFile.getAbsolutePath())
|
||||
.enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork();
|
||||
|
||||
System.out.println(model.summary());
|
||||
INDArray input = Nd4j.create(new int[]{10, 3});
|
||||
|
||||
model.output(input);
|
||||
} finally {
|
||||
KerasLossUtils.clearCustomLoss();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
|
@ -29,7 +29,7 @@ import org.deeplearning4j.text.sentenceiterator.PrefetchingSentenceIterator;
|
|||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SynchronizedSentenceIterator;
|
||||
import org.deeplearning4j.common.util.DL4JFileUtils;
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.slf4j.Logger;
|
||||
|
|
|
@ -47,7 +47,7 @@ import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
|||
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
|
||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
|
||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
|
|
@ -47,7 +47,7 @@ import org.deeplearning4j.models.word2vec.VocabWord;
|
|||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabConstructor;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
|||
import org.deeplearning4j.models.word2vec.Huffman;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.deeplearning4j.text.invertedindex.InvertedIndex;
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.threadly.concurrent.PriorityScheduler;
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.deeplearning4j.text.sentenceiterator;
|
|||
|
||||
import lombok.NonNull;
|
||||
|
||||
import org.deeplearning4j.core.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ public class Deconvolution2D extends ConvolutionLayer {
|
|||
protected Deconvolution2D(BaseConvBuilder<?> builder) {
|
||||
super(builder);
|
||||
initializeConstraints(builder);
|
||||
if(builder instanceof Builder){
|
||||
this.cnn2dDataFormat = ((Builder) builder).format;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean hasBias() {
|
||||
|
@ -136,7 +139,7 @@ public class Deconvolution2D extends ConvolutionLayer {
|
|||
|
||||
private CNN2DFormat format = CNN2DFormat.NCHW;
|
||||
|
||||
public Builder format(CNN2DFormat format){
|
||||
public Builder dataFormat(CNN2DFormat format){
|
||||
this.format = format;
|
||||
return this;
|
||||
}
|
||||
|
|
|
@ -310,11 +310,21 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
String layerName = conf.getLayer().getLayerName();
|
||||
if (layerName == null)
|
||||
layerName = "(not named)";
|
||||
throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName
|
||||
|
||||
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + layerConf().getCnn2dDataFormat().dimensionNames()
|
||||
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId());
|
||||
+ layerId();
|
||||
|
||||
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
|
||||
if(input.size(dimIfWrongFormat) == inDepth){
|
||||
//User might have passed NCHW data to a NHWC net, or vice versa?
|
||||
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
|
||||
}
|
||||
|
||||
|
||||
throw new DL4JInvalidInputException(s);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -190,12 +190,21 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
String layerName = conf.getLayer().getLayerName();
|
||||
if (layerName == null)
|
||||
layerName = "(not named)";
|
||||
throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
|
||||
|
||||
String s = "Cannot do forward pass in Deconvolution2D layer (layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data input channels = " + input.size(cDim) + ", "
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(cDim) + ", "
|
||||
+ (nchw ? "[minibatch,inputDepth,height,width]" : "[minibatch,height,width,inputDepth]") + "="
|
||||
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId());
|
||||
+ layerId();
|
||||
|
||||
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
|
||||
if(input.size(dimIfWrongFormat) == inDepth){
|
||||
//User might have passed NCHW data to a NHWC net, or vice versa?
|
||||
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
|
||||
}
|
||||
|
||||
throw new DL4JInvalidInputException(s);
|
||||
}
|
||||
int kH = (int) weights.size(2);
|
||||
int kW = (int) weights.size(3);
|
||||
|
|
|
@ -183,13 +183,21 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
|||
String layerName = conf.getLayer().getLayerName();
|
||||
if (layerName == null)
|
||||
layerName = "(not named)";
|
||||
throw new DL4JInvalidInputException("Cannot do forward pass in DepthwiseConvolution2D layer " +
|
||||
|
||||
String s = "Cannot do forward pass in DepthwiseConvolution2D layer " +
|
||||
"(layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data input channels = " + input.size(1) + ", "
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(1) + ", "
|
||||
+ (nchw ? "[minibatch,inputDepth,height,width]=" : "[minibatch,height,width,inputDepth]=")
|
||||
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId());
|
||||
+ layerId();
|
||||
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
|
||||
if(input.size(dimIfWrongFormat) == inDepth){
|
||||
//User might have passed NCHW data to a NHWC net, or vice versa?
|
||||
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
|
||||
}
|
||||
|
||||
throw new DL4JInvalidInputException(s);
|
||||
}
|
||||
int kH = (int) depthWiseWeights.size(0);
|
||||
int kW = (int) depthWiseWeights.size(1);
|
||||
|
|
|
@ -211,11 +211,20 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
String layerName = conf.getLayer().getLayerName();
|
||||
if (layerName == null)
|
||||
layerName = "(not named)";
|
||||
throw new DL4JInvalidInputException("Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
|
||||
|
||||
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
||||
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId());
|
||||
+ layerId();
|
||||
|
||||
int dimIfWrongFormat = format == CNN2DFormat.NHWC ? 1 : 3;
|
||||
if(input.size(dimIfWrongFormat) == inDepth){
|
||||
//User might have passed NCHW data to a NHWC net, or vice versa?
|
||||
s += "\n" + ConvolutionUtils.NCHW_NHWC_ERROR_MSG;
|
||||
}
|
||||
|
||||
throw new DL4JInvalidInputException(s);
|
||||
}
|
||||
int kH = (int) depthWiseWeights.size(2);
|
||||
int kW = (int) depthWiseWeights.size(3);
|
||||
|
|
|
@ -20,7 +20,7 @@ import lombok.*;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||
import org.deeplearning4j.util.ThreadUtils;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
|
|
@ -27,8 +27,8 @@ import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostPro
|
|||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
|
||||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
|
||||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
|
||||
import org.deeplearning4j.util.ThreadUtils;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||
import org.nd4j.linalg.api.memory.enums.*;
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.deeplearning4j.optimize.solvers.accumulation;
|
|||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.common.util.ThreadUtils;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
|
@ -28,8 +29,6 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
|
||||
import org.deeplearning4j.util.ThreadUtils;
|
||||
|
||||
/**
|
||||
* This BlockingQueue implementation is suited only for symmetric gradients updates, and should NOT be used anywhere else.
|
||||
*
|
||||
|
|
|
@ -48,6 +48,13 @@ import java.util.Arrays;
|
|||
*/
|
||||
public class ConvolutionUtils {
|
||||
|
||||
public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first)" +
|
||||
" or NHWC (channels last) format for input images and activations.\n" +
|
||||
"Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using" +
|
||||
" .setInputType(InputType.convolutional(height, width, depth, CNN2DForman.NCHW/NHWC)).\n" +
|
||||
"ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network";
|
||||
|
||||
|
||||
private static final int[] ONES = new int[]{1, 1};
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
|
@ -16,6 +17,7 @@
|
|||
|
||||
package org.nd4j.linalg.dataset;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
||||
|
@ -43,7 +45,7 @@ public class ExistingMiniBatchDataSetIterator implements DataSetIterator {
|
|||
* Create with the given root directory, using the default filename pattern {@link #DEFAULT_PATTERN}
|
||||
* @param rootDir the root directory to use
|
||||
*/
|
||||
public ExistingMiniBatchDataSetIterator(File rootDir) {
|
||||
public ExistingMiniBatchDataSetIterator(@NonNull File rootDir) {
|
||||
this(rootDir, DEFAULT_PATTERN);
|
||||
}
|
||||
|
||||
|
@ -53,7 +55,7 @@ public class ExistingMiniBatchDataSetIterator implements DataSetIterator {
|
|||
* @param pattern The filename pattern to use. Used with {@code String.format(pattern,idx)}, where idx is an
|
||||
* integer, starting at 0.
|
||||
*/
|
||||
public ExistingMiniBatchDataSetIterator(File rootDir, String pattern) {
|
||||
public ExistingMiniBatchDataSetIterator(@NonNull File rootDir, String pattern) {
|
||||
this.rootDir = rootDir;
|
||||
totalBatches = rootDir.list().length;
|
||||
this.pattern = pattern;
|
||||
|
|
|
@ -38,7 +38,7 @@ import java.util.Map;
|
|||
*/
|
||||
public abstract class SameDiffLoss implements ILossFunction {
|
||||
protected transient SameDiff sd;
|
||||
protected transient SDVariable scoreVariable;
|
||||
protected transient SDVariable scorePerExampleVariable;
|
||||
|
||||
protected SameDiffLoss() {
|
||||
|
||||
|
@ -60,7 +60,8 @@ public abstract class SameDiffLoss implements ILossFunction {
|
|||
sd = SameDiff.create();
|
||||
SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1);
|
||||
SDVariable labels = sd.placeHolder("labels", dataType, -1);
|
||||
scoreVariable = this.defineLoss(sd, layerInput, labels);
|
||||
scorePerExampleVariable = this.defineLoss(sd, layerInput, labels);
|
||||
scorePerExampleVariable.markAsLoss();
|
||||
sd.createGradFunction("layerInput");
|
||||
}
|
||||
|
||||
|
@ -112,7 +113,7 @@ public abstract class SameDiffLoss implements ILossFunction {
|
|||
m.put("labels", labels);
|
||||
m.put("layerInput", output);
|
||||
|
||||
INDArray scoreArr = sd.outputSingle(m,scoreVariable.name());
|
||||
INDArray scoreArr = sd.outputSingle(m, scorePerExampleVariable.name());
|
||||
|
||||
if (mask != null) {
|
||||
LossUtil.applyMask(scoreArr, mask);
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.common.util;
|
||||
|
||||
public class ThreadUtils {
|
||||
|
||||
private ThreadUtils(){ }
|
||||
|
||||
public static void uncheckedSleep(long sleepTimeMs){
|
||||
try{
|
||||
Thread.sleep(sleepTimeMs);
|
||||
} catch (InterruptedException e){ }
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue