More test fixes

master
Brian Rosenberger 2022-10-06 13:22:06 +02:00
parent 4ad1987a07
commit 11ba7a59c1
115 changed files with 252 additions and 351 deletions

View File

@ -66,6 +66,8 @@ dependencies {
implementation projects.cavisDnn.cavisDnnParallelwrapper implementation projects.cavisDnn.cavisDnnParallelwrapper
implementation projects.cavisZoo.cavisZooModels implementation projects.cavisZoo.cavisZooModels
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
} }
test { test {

View File

@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.io.*; import java.io.*;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -154,6 +155,9 @@ public class IntegrationTestRunner {
evaluationClassesSeen = new HashMap<>(); evaluationClassesSeen = new HashMap<>();
} }
public static void runTest(TestCase tc, Path testDir) throws Exception {
runTest(tc, testDir.toFile());
}
public static void runTest(TestCase tc, File testDir) throws Exception { public static void runTest(TestCase tc, File testDir) throws Exception {
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled. BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
//This could alternatively be done via maven surefire configuration //This could alternatively be done via maven surefire configuration

View File

@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import java.io.File; import java.io.File;
import java.nio.file.Path;
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated") ////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTestsDL4J extends BaseDL4JTest { public class IntegrationTestsDL4J extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir @TempDir
public File testDir; public Path testDir;
@AfterAll @AfterAll
public static void afterClass(){ public static void afterClass(){

View File

@ -30,12 +30,6 @@ import java.io.File;
public class IntegrationTestsSameDiff extends BaseDL4JTest { public class IntegrationTestsSameDiff extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir @TempDir
public File testDir; public File testDir;

View File

@ -65,6 +65,7 @@ dependencies {
/*Logging*/ /*Logging*/
api 'org.slf4j:slf4j-api:1.7.30' api 'org.slf4j:slf4j-api:1.7.30'
api 'org.slf4j:slf4j-simple:1.7.25'
api "org.apache.logging.log4j:log4j-core:2.17.0" api "org.apache.logging.log4j:log4j-core:2.17.0"
api "ch.qos.logback:logback-classic:1.2.3" api "ch.qos.logback:logback-classic:1.2.3"

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp; import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays; import java.util.Arrays;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform { public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName, public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.transform.transform.BaseColumnTransform;
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayMathFunctionTransform extends BaseColumnTransform { public class NDArrayMathFunctionTransform extends BaseColumnTransform {
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD //Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray; package org.datavec.api.transform.ndarray;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathOp; import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayScalarOpTransform extends BaseColumnTransform { public class NDArrayScalarOpTransform extends BaseColumnTransform {
private final MathOp mathOp; private final MathOp mathOp;

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.transform.string; package org.datavec.api.transform.transform.string;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
@ -31,6 +32,7 @@ import java.util.Collections;
import java.util.List; import java.util.List;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform { public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
/** /**
* @param columnName The name of the column to convert * @param columnName The name of the column to convert

View File

@ -28,4 +28,5 @@ dependencies {
implementation "commons-io:commons-io" implementation "commons-io:commons-io"
testImplementation projects.cavisNd4j.cavisNd4jCommonTests testImplementation projects.cavisNd4j.cavisNd4jCommonTests
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
} }

View File

@ -151,7 +151,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
} }
FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng); FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng);
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples, BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples,
numLabels, 0, batchSize, null); numLabels, 0, batchSize, (String) null);
inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest)); inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
} }

View File

@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*; import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
/**
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
*
* @author saudet
*/
public class NativeImageLoader extends BaseImageLoader { public class NativeImageLoader extends BaseImageLoader {
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024; private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
private byte[] buffer = null; private byte[] buffer = null;
@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader {
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"}; "PNG", "TIF", "TIFF", "EXR", "WEBP"};
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); protected OpenCVFrameConverter.ToMat converter;
boolean direct = !Loader.getPlatform().startsWith("android"); boolean direct = !Loader.getPlatform().startsWith("android");
/** /**
* Loads images with no scaling or conversion. * Loads images with no scaling or conversion.
*/ */
public NativeImageLoader() {} public NativeImageLoader() {
this.converter = new OpenCVFrameConverter.ToMat();
}
/** /**
* Instantiate an image with the given * Instantiate an image with the given
@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader {
*/ */
public NativeImageLoader(long height, long width) { public NativeImageLoader(long height, long width) {
this();
this.height = height; this.height = height;
this.width = width; this.width = width;
} }
@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader {
* @param channels the number of channels for the image* * @param channels the number of channels for the image*
*/ */
public NativeImageLoader(long height, long width, long channels) { public NativeImageLoader(long height, long width, long channels) {
this.height = height; this(height, width);
this.width = width;
this.channels = channels; this.channels = channels;
} }
@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader {
} }
protected NativeImageLoader(NativeImageLoader other) { protected NativeImageLoader(NativeImageLoader other) {
this.height = other.height; this(other.height, other.width, other.channels, other.multiPageMode);
this.width = other.width;
this.channels = other.channels;
this.centerCropIfNeeded = other.centerCropIfNeeded; this.centerCropIfNeeded = other.centerCropIfNeeded;
this.imageTransform = other.imageTransform; this.imageTransform = other.imageTransform;
this.multiPageMode = other.multiPageMode;
} }
@Override @Override
@ -252,8 +256,27 @@ public class NativeImageLoader extends BaseImageLoader {
@Override @Override
public INDArray asMatrix(File f, boolean nchw) throws IOException { public INDArray asMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
return asMatrix(bis, nchw); INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
// Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (mat == null || mat.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
if (pix == null) {
throw new IOException("Could not decode image from input stream");
}
mat = convert(pix);
pixDestroy(pix);
}
a = asMatrix(mat);
mat.deallocate();
}
if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
} }
} }
@ -264,6 +287,8 @@ public class NativeImageLoader extends BaseImageLoader {
@Override @Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new RuntimeException("not implemented");
/*
Mat mat = streamToMat(inputStream); Mat mat = streamToMat(inputStream);
INDArray a; INDArray a;
if (this.multiPageMode != null) { if (this.multiPageMode != null) {
@ -286,6 +311,8 @@ public class NativeImageLoader extends BaseImageLoader {
} else { } else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC return a.permute(0, 2, 3, 1); //NCHW to NHWC
} }
*/
} }
/** /**
@ -297,7 +324,7 @@ public class NativeImageLoader extends BaseImageLoader {
private Mat streamToMat(InputStream is) throws IOException { private Mat streamToMat(InputStream is) throws IOException {
if(buffer == null){ if(buffer == null){
buffer = IOUtils.toByteArray(is); buffer = IOUtils.toByteArray(is);
if(buffer.length <= 0){ if(buffer.length == 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
} }
bufferMat = new Mat(buffer); bufferMat = new Mat(buffer);
@ -354,9 +381,13 @@ public class NativeImageLoader extends BaseImageLoader {
@Override @Override
public Image asImageMatrix(File f, boolean nchw) throws IOException { public Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { Mat image = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
return asImageMatrix(bis, nchw); INDArray a = asMatrix(image);
} if(!nchw)
a = a.permute(0,2,3,1); //NCHW to NHWC
Image i = new Image(a, image.channels(), image.rows(), image.cols());
image.deallocate();
return i;
} }
@Override @Override
@ -366,7 +397,8 @@ public class NativeImageLoader extends BaseImageLoader {
@Override @Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream); throw new RuntimeException("Deprecated. Not implemented.");
/*Mat mat = streamToMat(inputStream);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols()); PIX pix = pixReadMem(mat.data(), mat.cols());
@ -383,6 +415,8 @@ public class NativeImageLoader extends BaseImageLoader {
image.deallocate(); image.deallocate();
return i; return i;
*/
} }
/** /**
@ -545,10 +579,15 @@ public class NativeImageLoader extends BaseImageLoader {
} }
public void asMatrixView(InputStream is, INDArray view) throws IOException { public void asMatrixView(InputStream is, INDArray view) throws IOException {
Mat mat = streamToMat(is); throw new RuntimeException("Not implemented");
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
}
public void asMatrixView(String filename, INDArray view) throws IOException {
Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
//Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols()); PIX pix = pixReadMem(image.data(), image.cols());
if (pix == null) { if (pix == null) {
throw new IOException("Could not decode image from input stream"); throw new IOException("Could not decode image from input stream");
} }
@ -561,14 +600,8 @@ public class NativeImageLoader extends BaseImageLoader {
image.deallocate(); image.deallocate();
} }
public void asMatrixView(String filename, INDArray view) throws IOException {
asMatrixView(new File(filename), view);
}
public void asMatrixView(File f, INDArray view) throws IOException { public void asMatrixView(File f, INDArray view) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { asMatrixView(f.getAbsolutePath(), view);
asMatrixView(bis, view);
}
} }
public void asMatrixView(Mat image, INDArray view) throws IOException { public void asMatrixView(Mat image, INDArray view) throws IOException {

View File

@ -53,6 +53,10 @@ import java.io.*;
import java.net.URI; import java.net.URI;
import java.util.*; import java.util.*;
/**
* Base class for the image record reader
*
*/
@Slf4j @Slf4j
public abstract class BaseImageRecordReader extends BaseRecordReader { public abstract class BaseImageRecordReader extends BaseRecordReader {
protected boolean finishedInputStreamSplit; protected boolean finishedInputStreamSplit;
@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i), ((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
features.tensorAlongDimension(i, 1, 2, 3)); features.tensorAlongDimension(i, 1, 2, 3));
} catch (Exception e) { } catch (Exception e) {
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath()); System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage());
e.printStackTrace();
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
@ -38,6 +39,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"borderValue"}) @JsonIgnoreProperties({"borderValue"})
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class BoxImageTransform extends BaseImageTransform<Mat> { public class BoxImageTransform extends BaseImageTransform<Mat> {
private int width; private int width;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude;
@ -32,6 +33,7 @@ import static org.bytedeco.opencv.global.opencv_core.*;
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class FlipImageTransform extends BaseImageTransform<Mat> { public class FlipImageTransform extends BaseImageTransform<Mat> {
/** /**

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class LargestBlobCropTransform extends BaseImageTransform<Mat> { public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
protected org.nd4j.linalg.api.rng.Random rng; protected org.nd4j.linalg.api.rng.Random rng;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import java.util.Random; import java.util.Random;
@ -28,6 +29,7 @@ import java.util.Random;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class MultiImageTransform extends BaseImageTransform<Mat> { public class MultiImageTransform extends BaseImageTransform<Mat> {
private PipelineImageTransform transform; private PipelineImageTransform transform;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
@ -32,6 +33,7 @@ import java.util.*;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class PipelineImageTransform extends BaseImageTransform<Mat> { public class PipelineImageTransform extends BaseImageTransform<Mat> {
protected List<Pair<ImageTransform, Double>> imageTransforms; protected List<Pair<ImageTransform, Double>> imageTransforms;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"rng", "converter"}) @JsonIgnoreProperties({"rng", "converter"})
@JsonInclude(JsonInclude.Include.NON_NULL) @JsonInclude(JsonInclude.Include.NON_NULL)
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class RandomCropTransform extends BaseImageTransform<Mat> { public class RandomCropTransform extends BaseImageTransform<Mat> {
protected int outputHeight; protected int outputHeight;

View File

@ -612,28 +612,6 @@ public class TestNativeImageLoader {
NativeImageLoader il = new NativeImageLoader(32, 32, 3); NativeImageLoader il = new NativeImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean) //asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f); Image i_nchw = il.asImageMatrix(f);
@ -642,20 +620,6 @@ public class TestNativeImageLoader {
assertEquals(i_nchw.getImage(), i_nchw2.getImage()); assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
} }
} }

View File

@ -20,11 +20,8 @@
package org.nd4j.autodiff.functions; package org.nd4j.autodiff.functions;
import lombok.Data; import lombok.*;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
protected SameDiff sameDiff; protected SameDiff sameDiff;
@Getter @Getter
@Setter
protected String varName; protected String varName;
@Getter @Getter
@Setter @Setter
protected VariableType variableType; protected VariableType variableType;
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
return varName; return varName;
} }
public void setVarName(String varName) {
this.varName = varName;
}
/**
* @deprecated Use {@link #name()}
*/
@Deprecated
public String getVarName(){
return name();
}
/** /**
* Returns true if this variable is a place holder * Returns true if this variable is a place holder
* @return * @return

View File

@ -39,5 +39,6 @@ public class Variable {
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
protected SDVariable gradient; //Variable corresponding to the gradient of this variable protected SDVariable gradient; //Variable corresponding to the gradient of this variable
@Builder.Default
protected int variableIndex = -1; protected int variableIndex = -1;
} }

View File

@ -76,7 +76,7 @@ public class SameDiffUtils {
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) { public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
" be specified when using external errors: got %s", inputs); " be specified when using external errors: got %s", (Object) inputs);
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
fn.outputVariable(); fn.outputVariable();
return fn; return fn;

View File

@ -49,7 +49,7 @@ import java.io.Serializable;
import java.util.List; import java.util.List;
@Getter @Getter
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> { public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10; public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;

View File

@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
import lombok.Data; import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.EqualsAndHashCode;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends BaseHistogram { public class Histogram extends BaseHistogram {
private final String title; private final String title;
private final double lower; private final double lower;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.memory.deallocation; package org.nd4j.linalg.api.memory.deallocation;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.Deallocator;
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference; import java.lang.ref.WeakReference;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class DeallocatableReference extends WeakReference<Deallocatable> { public class DeallocatableReference extends WeakReference<Deallocatable> {
private String id; private String id;
private Deallocator deallocator; private Deallocator deallocator;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops; package org.nd4j.linalg.api.ops;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.List;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation { public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
protected boolean keepDims = false; protected boolean keepDims = false;

View File

@ -44,6 +44,9 @@ import java.lang.reflect.Array;
import java.util.*; import java.util.*;
@Slf4j @Slf4j
@Builder
@AllArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class DynamicCustomOp extends DifferentialFunction implements CustomOp { public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
private String opName; private String opName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -35,6 +36,7 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Flatten extends DynamicCustomOp { public class Flatten extends DynamicCustomOp {
private int order; private int order;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import java.util.List; import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@EqualsAndHashCode(callSuper = false)
public abstract class BaseCompatOp extends DynamicCustomOp { public abstract class BaseCompatOp extends DynamicCustomOp {
protected String frameName; protected String frameName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -37,6 +38,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Enter extends BaseCompatOp { public class Enter extends BaseCompatOp {
protected boolean isConstant; protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat; package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class While extends BaseCompatOp { public class While extends BaseCompatOp {
protected boolean isConstant; protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum; package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -36,6 +37,7 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class FirstIndex extends BaseIndexAccumulation { public class FirstIndex extends BaseIndexAccumulation {
protected Condition condition; protected Condition condition;
protected double compare; protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum; package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
@ -38,6 +39,7 @@ import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LastIndex extends BaseIndexAccumulation { public class LastIndex extends BaseIndexAccumulation {
protected Condition condition; protected Condition condition;
protected double compare; protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmax extends DynamicCustomOp { public class ArgAmax extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmin extends DynamicCustomOp { public class ArgAmin extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgMax extends DynamicCustomOp { public class ArgMax extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom; package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class ArgMin extends DynamicCustomOp { public class ArgMin extends DynamicCustomOp {
protected boolean keepDims = false; protected boolean keepDims = false;
private int[] dimensions; private int[] dimensions;

View File

@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder;
import lombok.Data; import lombok.*;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv1DConfig extends BaseConvolutionConfig { public class Conv1DConfig extends BaseConvolutionConfig {
public static final String NCW = "NCW"; public static final String NCW = "NCW";
public static final String NWC = "NWC"; public static final String NWC = "NWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.WeightsFormat; import org.nd4j.enums.WeightsFormat;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv2DConfig extends BaseConvolutionConfig { public class Conv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC"; public static final String NHWC = "NHWC";

View File

@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv3DConfig extends BaseConvolutionConfig { public class Conv3DConfig extends BaseConvolutionConfig {
public static final String NDHWC = "NDHWC"; public static final String NDHWC = "NDHWC";
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv2DConfig extends BaseConvolutionConfig { public class DeConv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW"; public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC"; public static final String NHWC = "NHWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv3DConfig extends BaseConvolutionConfig { public class DeConv3DConfig extends BaseConvolutionConfig {
public static final String NCDHW = "NCDHW"; public static final String NCDHW = "NCDHW";
public static final String NDHWC = "NDHWC"; public static final String NDHWC = "NDHWC";

View File

@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.util.ConvConfigUtil; import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig { public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
private double alpha, beta, bias; private double alpha, beta, bias;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling2DConfig extends BaseConvolutionConfig { public class Pooling2DConfig extends BaseConvolutionConfig {
@Builder.Default private long kH = -1, kW = -1; @Builder.Default private long kH = -1, kW = -1;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data @Data
@Builder @Builder
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling3DConfig extends BaseConvolutionConfig { public class Pooling3DConfig extends BaseConvolutionConfig {
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides

View File

@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.*; import java.util.*;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class Mmul extends DynamicCustomOp { public class Mmul extends DynamicCustomOp {
protected MMulTranspose mt; protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.List; import java.util.List;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class MmulBp extends DynamicCustomOp { public class MmulBp extends DynamicCustomOp {
protected MMulTranspose mt; protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class BatchMmul extends DynamicCustomOp { public class BatchMmul extends DynamicCustomOp {
protected int transposeA; protected int transposeA;

View File

@ -39,10 +39,15 @@ import java.util.Map;
public class BalanceMinibatches { public class BalanceMinibatches {
private DataSetIterator dataSetIterator; private DataSetIterator dataSetIterator;
private int numLabels; private int numLabels;
@Builder.Default
private Map<Integer, List<File>> paths = Maps.newHashMap(); private Map<Integer, List<File>> paths = Maps.newHashMap();
@Builder.Default
private int miniBatchSize = -1; private int miniBatchSize = -1;
@Builder.Default
private File rootDir = new File("minibatches"); private File rootDir = new File("minibatches");
@Builder.Default
private File rootSaveDir = new File("minibatchessave"); private File rootSaveDir = new File("minibatchessave");
@Builder.Default
private List<File> labelRootDirs = new ArrayList<>(); private List<File> labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization; private DataNormalization dataNormalization;

View File

@ -20,7 +20,8 @@
package org.nd4j.linalg.learning.config; package org.nd4j.linalg.learning.config;
import lombok.*; import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaMaxUpdater; import org.nd4j.linalg.learning.AdaMaxUpdater;
import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.GradientUpdater;
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9; public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999; public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate @lombok.Builder.Default
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
private ISchedule learningRateSchedule; private ISchedule learningRateSchedule;
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate @lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate @lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate

View File

@ -335,20 +335,6 @@ public class OpProfiler {
} }
} }
/**
* Dev-time method.
*
* @return
*/
protected StackAggregator getMixedOrderAggregator() {
// FIXME: remove this method, or make it protected
return mixedOrderAggregator;
}
public StackAggregator getScalarAggregator() {
return scalarAggregator;
}
protected void updatePairs(String opName, String opClass) { protected void updatePairs(String opName, String opClass) {
// now we save pairs of ops/classes // now we save pairs of ops/classes
String cOpNameKey = prevOpName + " -> " + opName; String cOpNameKey = prevOpName + " -> " + opName;

View File

@ -58,13 +58,6 @@ public abstract class BaseDL4JTest {
return DEFAULT_THREADS; return DEFAULT_THREADS;
} }
/**
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 90_000;
}
/** /**
* Override this to set the profiling mode for the tests defined in the child class * Override this to set the profiling mode for the tests defined in the child class
*/ */

View File

@ -16,4 +16,6 @@ dependencies {
implementation 'org.apache.commons:commons-math3' implementation 'org.apache.commons:commons-math3'
implementation 'org.apache.commons:commons-lang3' implementation 'org.apache.commons:commons-lang3'
implementation 'org.apache.commons:commons-compress' implementation 'org.apache.commons:commons-compress'
testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'
} }

View File

@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test;
public class TestDataSets extends BaseDL4JTest { public class TestDataSets extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test @Test
public void testTinyImageNetExists() throws Exception { public void testTinyImageNetExists() throws Exception {
//Simple sanity check on extracting //Simple sanity check on extracting

View File

@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue;
*/ */
public class SvhnDataFetcherTest extends BaseDL4JTest { public class SvhnDataFetcherTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
}
@Test @Test
public void testSvhnDataFetcher() throws Exception { public void testSvhnDataFetcher() throws Exception {
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access

View File

@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class DataSetIteratorTest extends BaseDL4JTest { public class DataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads
}
@Test @Test
public void testBatchSizeOfOneIris() throws Exception { public void testBatchSizeOfOneIris() throws Exception {
//Test for (a) iterators returning correct number of examples, and //Test for (a) iterators returning correct number of examples, and

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.earlystopping; package org.deeplearning4j.earlystopping;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
} }
@Data @Data
@EqualsAndHashCode(callSuper = false)
public static class TestListener extends BaseTrainingListener { public static class TestListener extends BaseTrainingListener {
private int countEpochStart = 0; private int countEpochStart = 0;
private int countEpochEnd = 0; private int countEpochEnd = 0;

View File

@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
////@Ignore ////@Ignore
public class AttentionLayerTest extends BaseDL4JTest { public class AttentionLayerTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testSelfAttentionLayer() { public void testSelfAttentionLayer() {

View File

@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradient2dSimple() { public void testGradient2dSimple() {
DataNormalization scaler = new NormalizerMinMaxScaler(); DataNormalization scaler = new NormalizerMinMaxScaler();

View File

@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 180000;
}
@Test @Test
public void testCnn1DWithLocallyConnected1D() { public void testCnn1DWithLocallyConnected1D() {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);

View File

@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testCnn3DPlain() { public void testCnn3DPlain() {
Nd4j.getRandom().setSeed(1337); Nd4j.getRandom().setSeed(1337);

View File

@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
return CNN2DFormat.values(); return CNN2DFormat.values();
} }
@Override
public long getTimeoutMilliseconds() {
return 999990000L;
}
@Test @Test
public void testGradientCNNMLN() { public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...

View File

@ -49,11 +49,6 @@ import java.util.Random;
////@Ignore ////@Ignore
public class CapsnetGradientCheckTest extends BaseDL4JTest { public class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testCapsNet() { public void testCapsNet() {

View File

@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testDropoutGradient() { public void testDropoutGradient() {
int minibatch = 3; int minibatch = 3;

View File

@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-3; private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testRNNGlobalPoolingBasicMultiLayer() { public void testRNNGlobalPoolingBasicMultiLayer() {
//Basic test of global pooling w/ LSTM //Basic test of global pooling w/ LSTM

View File

@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testMinibatchApplication() { public void testMinibatchApplication() {
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);

View File

@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 999999999L;
}
@Test @Test
public void testBasicIris() { public void testBasicIris() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
private static class GradientCheckSimpleScenario { private static class GradientCheckSimpleScenario {
private final ILossFunction lf; private final ILossFunction lf;
private final Activation act; private final Activation act;

View File

@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradientLRNSimple() { public void testGradientLRNSimple() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testLSTMBasicMultiLayer() { public void testLSTMBasicMultiLayer() {
//Basic test of GravesLSTM layer //Basic test of GravesLSTM layer

View File

@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-5; private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void lossFunctionGradientCheck() { public void lossFunctionGradientCheck() {
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(), ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),

View File

@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testGradientNoBiasDenseOutput() { public void testGradientNoBiasDenseOutput() {

View File

@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testRnnLossLayer() { public void testRnnLossLayer() {
Nd4j.getRandom().setSeed(12345L); Nd4j.getRandom().setSeed(12345L);

View File

@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") ////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testBidirectionalWrapper() { public void testBidirectionalWrapper() {

View File

@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testMaskLayer() { public void testMaskLayer() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);

View File

@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE); Nd4j.setDataType(DataType.DOUBLE);
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testVaeAsMLP() { public void testVaeAsMLP() {
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output //Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output

View File

@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@TempDir @TempDir
public File testDir; public File testDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testYoloOutputLayer() { public void testYoloOutputLayer() {
int depthIn = 2; int depthIn = 2;

View File

@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest {
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
)); ));
@Override
public long getTimeoutMilliseconds() {
return 9999999L;
}
@AfterAll @AfterAll
public static void after() { public static void after() {
ImmutableSet<ClassPath.ClassInfo> info; ImmutableSet<ClassPath.ClassInfo> info;

View File

@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class CustomActivation extends BaseActivationFunction implements IActivation { public class CustomActivation extends BaseActivationFunction implements IActivation {
@Override @Override
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {

View File

@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest {
public void doBefore() { public void doBefore() {
} }
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testDnnForwardPass() { public void testDnnForwardPass() {
int nOut = 10; int nOut = 10;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.layers.samediff.testlayers; package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams; import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
@ -37,6 +38,7 @@ import java.util.Map;
@NoArgsConstructor @NoArgsConstructor
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class SameDiffDenseVertex extends SameDiffVertex { public class SameDiffDenseVertex extends SameDiffVertex {
private int nIn; private int nIn;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.multilayer; package org.deeplearning4j.nn.multilayer;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest {
} }
@Data @Data
@EqualsAndHashCode(callSuper = false)
public static class CheckModelsListener extends BaseTrainingListener { public static class CheckModelsListener extends BaseTrainingListener {
private Set<Class<?>> modelClasses = new HashSet<>(); private Set<Class<?>> modelClasses = new HashSet<>();

View File

@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j @Slf4j
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 1200000L;
}
/** /**
* This test ensures, that memory amount assigned to buffer is enough for any number of updates * This test ensures, that memory amount assigned to buffer is enough for any number of updates
* @throws Exception * @throws Exception

View File

@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657") ////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657")
@Timeout(120000L) //@Timeout(120000L)
public class SmartFancyBlockingQueueTest extends BaseDL4JTest { public class SmartFancyBlockingQueueTest extends BaseDL4JTest {
@Test @Test

View File

@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class TestCheckpointListener extends BaseDL4JTest { public class TestCheckpointListener extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@TempDir @TempDir
public File tempDir; public File tempDir;

View File

@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest {
@TempDir @TempDir
public File tempDir; public File tempDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test @Test
public void testSettingListenersUnsupervised() { public void testSettingListenersUnsupervised() {
//Pretrain layers should get copies of the listeners, in addition to the //Pretrain layers should get copies of the listeners, in addition to the

View File

@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void regressionTestMLP1() throws Exception { public void regressionTestMLP1() throws Exception {

View File

@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j @Slf4j
public class RegressionTest100a extends BaseDL4JTest { public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class RegressionTest100b3 extends BaseDL4JTest { public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources;
public class RegressionTest100b4 extends BaseDL4JTest { public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override @Override
public DataType getDataType() { public DataType getDataType() {
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest {
return DataType.FLOAT; return DataType.FLOAT;
} }
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testCustomLayer() throws Exception { public void testCustomLayer() throws Exception {

View File

@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestDistributionDeserializer extends BaseDL4JTest { public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test @Test
public void testDistributionDeserializer() throws Exception { public void testDistributionDeserializer() throws Exception {
//Test current format: //Test current format:

View File

@ -25,4 +25,5 @@ dependencies {
implementation "org.slf4j:slf4j-api" implementation "org.slf4j:slf4j-api"
implementation "org.apache.commons:commons-lang3" implementation "org.apache.commons:commons-lang3"
implementation "com.fasterxml.jackson.core:jackson-annotations"
} }

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers; package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.common.config.DL4JClassLoading;
@ -46,6 +47,7 @@ import java.util.List;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> { public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core; package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
@ -39,6 +40,7 @@ import java.util.Map;
*/ */
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class KerasMasking extends KerasLayer { public class KerasMasking extends KerasLayer {
private double maskingValue; private double maskingValue;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core; package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex;
@ -35,6 +36,7 @@ import java.util.Map;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class KerasMerge extends KerasLayer { public class KerasMerge extends KerasLayer {
private final String LAYER_FIELD_MODE = "mode"; private final String LAYER_FIELD_MODE = "mode";

Some files were not shown because too many files have changed in this diff Show More