More test fixes

master
Brian Rosenberger 2022-10-06 13:22:06 +02:00
parent 4ad1987a07
commit 11ba7a59c1
115 changed files with 252 additions and 351 deletions

View File

@ -66,6 +66,8 @@ dependencies {
implementation projects.cavisDnn.cavisDnnParallelwrapper
implementation projects.cavisZoo.cavisZooModels
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
}
test {

View File

@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.io.*;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
@ -154,6 +155,9 @@ public class IntegrationTestRunner {
evaluationClassesSeen = new HashMap<>();
}
public static void runTest(TestCase tc, Path testDir) throws Exception {
runTest(tc, testDir.toFile());
}
public static void runTest(TestCase tc, File testDir) throws Exception {
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
//This could alternatively be done via maven surefire configuration
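The guard above gates the whole runner on an integration-test flag. As a rough sketch of how such a guard can be written with JUnit 5 assumptions (the class body and the "integration.tests" property name are illustrative assumptions, not taken from this diff):

import org.junit.jupiter.api.Assumptions;

public abstract class BaseDL4JTest {
    // Minimal sketch: abort (skip) the calling test unless an integration-test flag is set.
    // The "integration.tests" system property name is assumed here for illustration only.
    public static void skipUnlessIntegrationTests() {
        boolean enabled = Boolean.getBoolean("integration.tests");
        Assumptions.assumeTrue(enabled, "Skipping test: integration test profile is not enabled");
    }
}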

View File

@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import java.io.File;
import java.nio.file.Path;
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTestsDL4J extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir
public File testDir;
public Path testDir;
@AfterAll
public static void afterClass(){
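The injected temporary directory changes here from File to Path; JUnit 5's @TempDir supports both types, and Path.toFile() bridges back to File-based helpers (such as the runTest(tc, File) overload shown earlier). A minimal sketch with illustrative class and test names:

import java.io.File;
import java.nio.file.Path;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

class TempDirPathExample {
    @TempDir
    Path testDir;    // JUnit 5 injects a fresh temporary directory

    @Test
    void usesTempDir() {
        File asFile = testDir.toFile();   // bridge to APIs that still expect a java.io.File
        // ... call code that needs a File here ...
    }
}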

View File

@ -30,12 +30,6 @@ import java.io.File;
public class IntegrationTestsSameDiff extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir
public File testDir;

View File

@ -65,6 +65,7 @@ dependencies {
/*Logging*/
api 'org.slf4j:slf4j-api:1.7.30'
api 'org.slf4j:slf4j-simple:1.7.25'
api "org.apache.logging.log4j:log4j-core:2.17.0"
api "ch.qos.logback:logback-classic:1.2.3"

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
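Adding @EqualsAndHashCode(callSuper = false) next to @Data is the recurring change in this commit: on a subclass, Lombok otherwise warns that it is generating equals()/hashCode() without consulting the superclass. A minimal sketch with illustrative class names:

import lombok.Data;
import lombok.EqualsAndHashCode;

@Data
class BaseTransformExample {
    private String columnName;
}

// callSuper = false: the generated equals()/hashCode() compare only this class's own
// fields and deliberately ignore BaseTransformExample's state, which also silences
// Lombok's "equals/hashCode on a subclass without callSuper" warning.
@Data
@EqualsAndHashCode(callSuper = false)
class ScaledTransformExample extends BaseTransformExample {
    private double scale;
}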

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.ndarray;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import com.fasterxml.jackson.annotation.JsonProperty;
@Data
@EqualsAndHashCode(callSuper = false)
public class NDArrayScalarOpTransform extends BaseColumnTransform {
private final MathOp mathOp;

View File

@ -21,6 +21,7 @@
package org.datavec.api.transform.transform.string;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import com.fasterxml.jackson.annotation.JsonProperty;
@ -31,6 +32,7 @@ import java.util.Collections;
import java.util.List;
@Data
@EqualsAndHashCode(callSuper = false)
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
/**
* @param columnName The name of the column to convert

View File

@ -28,4 +28,5 @@ dependencies {
implementation "commons-io:commons-io"
testImplementation projects.cavisNd4j.cavisNd4jCommonTests
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
}

View File

@ -151,7 +151,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
}
FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng);
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples,
numLabels, 0, batchSize, null);
numLabels, 0, batchSize, (String) null);
inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
}
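The added (String) cast resolves a varargs ambiguity: a bare null passed where BalancedPathFilter declares a String... parameter could mean either the whole array or a single element, and javac warns about the inexact argument. A self-contained sketch of the same situation (the filter method below is illustrative, not the DataVec API):

class VarargsNullExample {
    static void filter(int batchSize, String... validLabels) {
        // validLabels is either null (no array) or a one-element array, depending on the call below
    }

    public static void main(String[] args) {
        // filter(32, null);            // warning: is null the String[] itself or a single String element?
        filter(32, (String[]) null);    // explicit: pass no array at all
        filter(32, (String) null);      // explicit: pass a one-element array containing null
    }
}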

View File

@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
/**
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
*
* @author saudet
*/
public class NativeImageLoader extends BaseImageLoader {
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
private byte[] buffer = null;
@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader {
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
protected OpenCVFrameConverter.ToMat converter;
boolean direct = !Loader.getPlatform().startsWith("android");
/**
* Loads images with no scaling or conversion.
*/
public NativeImageLoader() {}
public NativeImageLoader() {
this.converter = new OpenCVFrameConverter.ToMat();
}
/**
* Instantiate an image with the given
@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader {
*/
public NativeImageLoader(long height, long width) {
this();
this.height = height;
this.width = width;
}
@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader {
* @param channels the number of channels for the image*
*/
public NativeImageLoader(long height, long width, long channels) {
this.height = height;
this.width = width;
this(height, width);
this.channels = channels;
}
@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader {
}
protected NativeImageLoader(NativeImageLoader other) {
this.height = other.height;
this.width = other.width;
this.channels = other.channels;
this(other.height, other.width, other.channels, other.multiPageMode);
this.centerCropIfNeeded = other.centerCropIfNeeded;
this.imageTransform = other.imageTransform;
this.multiPageMode = other.multiPageMode;
}
@Override
@ -252,8 +256,27 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public INDArray asMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asMatrix(bis, nchw);
Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
INDArray a;
if (this.multiPageMode != null) {
a = asMatrix(mat.data(), mat.cols());
}else{
// Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (mat == null || mat.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
if (pix == null) {
throw new IOException("Could not decode image from input stream");
}
mat = convert(pix);
pixDestroy(pix);
}
a = asMatrix(mat);
mat.deallocate();
}
if(nchw) {
return a;
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
}
@ -264,6 +287,8 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
throw new RuntimeException("not implemented");
/*
Mat mat = streamToMat(inputStream);
INDArray a;
if (this.multiPageMode != null) {
@ -286,6 +311,8 @@ public class NativeImageLoader extends BaseImageLoader {
} else {
return a.permute(0, 2, 3, 1); //NCHW to NHWC
}
*/
}
/**
@ -297,7 +324,7 @@ public class NativeImageLoader extends BaseImageLoader {
private Mat streamToMat(InputStream is) throws IOException {
if(buffer == null){
buffer = IOUtils.toByteArray(is);
if(buffer.length <= 0){
if(buffer.length == 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
}
bufferMat = new Mat(buffer);
@ -354,9 +381,13 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public Image asImageMatrix(File f, boolean nchw) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
return asImageMatrix(bis, nchw);
}
Mat image = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
INDArray a = asMatrix(image);
if(!nchw)
a = a.permute(0,2,3,1); //NCHW to NHWC
Image i = new Image(a, image.channels(), image.rows(), image.cols());
image.deallocate();
return i;
}
@Override
@ -366,7 +397,8 @@ public class NativeImageLoader extends BaseImageLoader {
@Override
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
Mat mat = streamToMat(inputStream);
throw new RuntimeException("Deprecated. Not implemented.");
/*Mat mat = streamToMat(inputStream);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
@ -383,6 +415,8 @@ public class NativeImageLoader extends BaseImageLoader {
image.deallocate();
return i;
*/
}
/**
@ -545,10 +579,15 @@ public class NativeImageLoader extends BaseImageLoader {
}
public void asMatrixView(InputStream is, INDArray view) throws IOException {
Mat mat = streamToMat(is);
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
throw new RuntimeException("Not implemented");
}
public void asMatrixView(String filename, INDArray view) throws IOException {
Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
//Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols());
PIX pix = pixReadMem(image.data(), image.cols());
if (pix == null) {
throw new IOException("Could not decode image from input stream");
}
@ -561,14 +600,8 @@ public class NativeImageLoader extends BaseImageLoader {
image.deallocate();
}
public void asMatrixView(String filename, INDArray view) throws IOException {
asMatrixView(new File(filename), view);
}
public void asMatrixView(File f, INDArray view) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
asMatrixView(bis, view);
}
asMatrixView(f.getAbsolutePath(), view);
}
public void asMatrixView(Mat image, INDArray view) throws IOException {

View File

@ -53,6 +53,10 @@ import java.io.*;
import java.net.URI;
import java.util.*;
/**
* Base class for the image record reader
*
*/
@Slf4j
public abstract class BaseImageRecordReader extends BaseRecordReader {
protected boolean finishedInputStreamSplit;
@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
features.tensorAlongDimension(i, 1, 2, 3));
} catch (Exception e) {
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath());
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage());
e.printStackTrace();
throw new RuntimeException(e);
}
}

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.experimental.Accessors;
@ -38,6 +39,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"borderValue"})
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = false)
public class BoxImageTransform extends BaseImageTransform<Mat> {
private int width;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import com.fasterxml.jackson.annotation.JsonInclude;
@ -32,6 +33,7 @@ import static org.bytedeco.opencv.global.opencv_core.*;
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = false)
public class FlipImageTransform extends BaseImageTransform<Mat> {
/**

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j;
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
@Data
@EqualsAndHashCode(callSuper = false)
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
protected org.nd4j.linalg.api.rng.Random rng;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.datavec.image.data.ImageWritable;
import java.util.Random;
@ -28,6 +29,7 @@ import java.util.Random;
import org.bytedeco.opencv.opencv_core.*;
@Data
@EqualsAndHashCode(callSuper = false)
public class MultiImageTransform extends BaseImageTransform<Mat> {
private PipelineImageTransform transform;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import org.datavec.image.data.ImageWritable;
@ -32,6 +33,7 @@ import java.util.*;
import org.bytedeco.opencv.opencv_core.*;
@Data
@EqualsAndHashCode(callSuper = false)
public class PipelineImageTransform extends BaseImageTransform<Mat> {
protected List<Pair<ImageTransform, Double>> imageTransforms;

View File

@ -21,6 +21,7 @@
package org.datavec.image.transform;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.ImageWritable;
import org.nd4j.linalg.factory.Nd4j;
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
@JsonIgnoreProperties({"rng", "converter"})
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@EqualsAndHashCode(callSuper = false)
public class RandomCropTransform extends BaseImageTransform<Mat> {
protected int outputHeight;

View File

@ -612,28 +612,6 @@ public class TestNativeImageLoader {
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
//asMatrix(File, boolean)
INDArray a_nchw = il.asMatrix(f);
INDArray a_nchw2 = il.asMatrix(f, true);
INDArray a_nhwc = il.asMatrix(f, false);
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw = il.asMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nchw2 = il.asMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
a_nhwc = il.asMatrix(is, false);
}
assertEquals(a_nchw, a_nchw2);
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
//asImageMatrix(File, boolean)
Image i_nchw = il.asImageMatrix(f);
@ -642,20 +620,6 @@ public class TestNativeImageLoader {
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
//asImageMatrix(InputStream, boolean)
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw = il.asImageMatrix(is);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nchw2 = il.asImageMatrix(is, true);
}
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
i_nhwc = il.asImageMatrix(is, false);
}
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
}
}
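The remaining assertions above exercise the File-based entry points that this commit keeps (the InputStream variants now throw). A short usage sketch of those entry points, assuming an arbitrary example.png on disk:

import java.io.File;
import java.util.Arrays;
import org.datavec.image.loader.NativeImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;

public class NativeImageLoaderUsage {
    public static void main(String[] args) throws Exception {
        NativeImageLoader loader = new NativeImageLoader(32, 32, 3);       // height, width, channels
        INDArray nchw = loader.asMatrix(new File("example.png"), true);    // [minibatch, channels, height, width]
        INDArray nhwc = loader.asMatrix(new File("example.png"), false);   // same data permuted to [minibatch, height, width, channels]
        System.out.println(Arrays.toString(nchw.shape()) + " / " + Arrays.toString(nhwc.shape()));
    }
}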

View File

@ -20,11 +20,8 @@
package org.nd4j.autodiff.functions;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

View File

@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
protected SameDiff sameDiff;
@Getter
@Setter
protected String varName;
@Getter
@Setter
protected VariableType variableType;
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
return varName;
}
public void setVarName(String varName) {
this.varName = varName;
}
/**
* @deprecated Use {@link #name()}
*/
@Deprecated
public String getVarName(){
return name();
}
/**
* Returns true if this variable is a place holder
* @return
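The hand-written setVarName() and the deprecated getVarName() are replaced by Lombok's @Getter/@Setter on the field, which generate accessors with the same names. A minimal sketch (the class here is an illustrative stand-in, not the real SDVariable):

import lombok.Getter;
import lombok.Setter;

class SdVariableSketch {
    @Getter
    @Setter
    protected String varName;   // Lombok generates getVarName() and setVarName(String)

    public String name() {      // preferred accessor kept by the class
        return varName;
    }
}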

View File

@ -39,5 +39,6 @@ public class Variable {
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
@Builder.Default
protected int variableIndex = -1;
}
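Without @Builder.Default, a field initializer such as variableIndex = -1 is ignored whenever the object is created through the Lombok-generated builder. A minimal sketch with an illustrative class:

import lombok.Builder;
import lombok.Getter;

@Getter
@Builder
class VariableSketch {
    @Builder.Default
    private int variableIndex = -1;   // preserved as the default when the builder does not set it

    private String name;              // no initializer, so no @Builder.Default needed
}

// VariableSketch.builder().name("x").build().getVariableIndex() returns -1;
// without @Builder.Default the builder-built instance would have variableIndex == 0.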

View File

@ -76,7 +76,7 @@ public class SameDiffUtils {
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
" be specified when using external errors: got %s", inputs);
" be specified when using external errors: got %s", (Object) inputs);
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
fn.outputVariable();
return fn;
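The (Object) cast matters because the message-formatting parameter is itself a varargs Object...: passing the SDVariable[] directly would spread its elements into separate arguments, while the cast keeps the whole array as the single argument matched by the one %s placeholder. A small self-contained sketch (check() is an illustrative method, not the ND4J Preconditions API):

import java.util.Arrays;

class VarargsArraySketch {
    static void check(String msg, Object... args) {
        System.out.println(msg + ": " + args.length + " argument(s) -> " + Arrays.toString(args));
    }

    public static void main(String[] args) {
        String[] inputs = {"a", "b", "c"};
        check("spread", inputs);            // the array is unpacked into three separate arguments
        check("single", (Object) inputs);   // the cast keeps it as one argument for a single %s placeholder
    }
}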

View File

@ -49,7 +49,7 @@ import java.io.Serializable;
import java.util.List;
@Getter
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;

View File

@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
import lombok.Data;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.EqualsAndHashCode;
@Data
@EqualsAndHashCode(callSuper = false)
public class Histogram extends BaseHistogram {
private final String title;
private final double lower;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.memory.deallocation;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.api.memory.Deallocator;
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
@Data
@EqualsAndHashCode(callSuper = false)
public class DeallocatableReference extends WeakReference<Deallocatable> {
private String id;
private Deallocator deallocator;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.List;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
protected boolean keepDims = false;

View File

@ -44,6 +44,9 @@ import java.lang.reflect.Array;
import java.util.*;
@Slf4j
@Builder
@AllArgsConstructor
@EqualsAndHashCode(callSuper = true)
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
private String opName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
@ -35,6 +36,7 @@ import java.util.List;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Flatten extends DynamicCustomOp {
private int order;

View File

@ -21,6 +21,8 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import java.util.List;
import lombok.EqualsAndHashCode;
import lombok.NonNull;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
import java.util.HashMap;
import java.util.Map;
@EqualsAndHashCode(callSuper = false)
public abstract class BaseCompatOp extends DynamicCustomOp {
protected String frameName;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -37,6 +38,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Enter extends BaseCompatOp {
protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class While extends BaseCompatOp {
protected boolean isConstant;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
@ -36,6 +37,7 @@ import java.util.List;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class FirstIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
@ -38,6 +39,7 @@ import java.util.Map;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LastIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmax extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -38,6 +39,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgAmin extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgMax extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -21,6 +21,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
@ -37,6 +38,7 @@ import java.util.List;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = false)
public class ArgMin extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;

View File

@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv1DConfig extends BaseConvolutionConfig {
public static final String NCW = "NCW";
public static final String NWC = "NWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.WeightsFormat;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC";

View File

@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Conv3DConfig extends BaseConvolutionConfig {
public static final String NDHWC = "NDHWC";
public static final String NCDHW = "NCDHW";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv2DConfig extends BaseConvolutionConfig {
public static final String NCHW = "NCHW";
public static final String NHWC = "NHWC";

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.util.ConvConfigUtil;
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class DeConv3DConfig extends BaseConvolutionConfig {
public static final String NCDHW = "NCDHW";
public static final String NDHWC = "NDHWC";

View File

@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
private double alpha, beta, bias;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling2DConfig extends BaseConvolutionConfig {
@Builder.Default private long kH = -1, kW = -1;

View File

@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
@Data
@Builder
@NoArgsConstructor
@EqualsAndHashCode(callSuper = false)
public class Pooling3DConfig extends BaseConvolutionConfig {
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides

View File

@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
import java.lang.reflect.Field;
import java.util.*;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class Mmul extends DynamicCustomOp {
protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
import java.util.List;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class MmulBp extends DynamicCustomOp {
protected MMulTranspose mt;

View File

@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class BatchMmul extends DynamicCustomOp {
protected int transposeA;

View File

@ -39,10 +39,15 @@ import java.util.Map;
public class BalanceMinibatches {
private DataSetIterator dataSetIterator;
private int numLabels;
@Builder.Default
private Map<Integer, List<File>> paths = Maps.newHashMap();
@Builder.Default
private int miniBatchSize = -1;
@Builder.Default
private File rootDir = new File("minibatches");
@Builder.Default
private File rootSaveDir = new File("minibatchessave");
@Builder.Default
private List<File> labelRootDirs = new ArrayList<>();
private DataNormalization dataNormalization;

View File

@ -20,7 +20,8 @@
package org.nd4j.linalg.learning.config;
import lombok.*;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaMaxUpdater;
import org.nd4j.linalg.learning.GradientUpdater;
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
@lombok.Builder.Default
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
private ISchedule learningRateSchedule;
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate

View File

@ -335,20 +335,6 @@ public class OpProfiler {
}
}
/**
* Dev-time method.
*
* @return
*/
protected StackAggregator getMixedOrderAggregator() {
// FIXME: remove this method, or make it protected
return mixedOrderAggregator;
}
public StackAggregator getScalarAggregator() {
return scalarAggregator;
}
protected void updatePairs(String opName, String opClass) {
// now we save pairs of ops/classes
String cOpNameKey = prevOpName + " -> " + opName;

View File

@ -58,13 +58,6 @@ public abstract class BaseDL4JTest {
return DEFAULT_THREADS;
}
/**
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 90_000;
}
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
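This removes the 90 second default that individual test classes were overriding throughout this commit. If a per-test limit is still wanted, JUnit 5's @Timeout annotation is one way to express it directly (a sketch, not something this diff introduces):

import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

class TimeoutSketch {
    @Test
    @Timeout(value = 90, unit = TimeUnit.SECONDS)   // roughly equivalent to the removed 90_000 ms default
    void boundedTest() {
        // ... test body ...
    }
}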

View File

@ -16,4 +16,6 @@ dependencies {
implementation 'org.apache.commons:commons-math3'
implementation 'org.apache.commons:commons-lang3'
implementation 'org.apache.commons:commons-compress'
testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'
}

View File

@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test;
public class TestDataSets extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L;
}
@Test
public void testTinyImageNetExists() throws Exception {
//Simple sanity check on extracting

View File

@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue;
*/
public class SvhnDataFetcherTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
}
@Test
public void testSvhnDataFetcher() throws Exception {
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access

View File

@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class DataSetIteratorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 360000; //Should run quickly; increased to large timeout due to occasional slow CI downloads
}
@Test
public void testBatchSizeOfOneIris() throws Exception {
//Test for (a) iterators returning correct number of examples, and

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.earlystopping;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
}
@Data
@EqualsAndHashCode(callSuper = false)
public static class TestListener extends BaseTrainingListener {
private int countEpochStart = 0;
private int countEpochEnd = 0;

View File

@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
////@Ignore
public class AttentionLayerTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testSelfAttentionLayer() {

View File

@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testGradient2dSimple() {
DataNormalization scaler = new NormalizerMinMaxScaler();

View File

@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 180000;
}
@Test
public void testCnn1DWithLocallyConnected1D() {
Nd4j.getRandom().setSeed(1337);

View File

@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testCnn3DPlain() {
Nd4j.getRandom().setSeed(1337);

View File

@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
return CNN2DFormat.values();
}
@Override
public long getTimeoutMilliseconds() {
return 999990000L;
}
@Test
public void testGradientCNNMLN() {
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...

View File

@ -49,11 +49,6 @@ import java.util.Random;
////@Ignore
public class CapsnetGradientCheckTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testCapsNet() {

View File

@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testDropoutGradient() {
int minibatch = 3;

View File

@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testRNNGlobalPoolingBasicMultiLayer() {
//Basic test of global pooling w/ LSTM

View File

@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testMinibatchApplication() {
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);

View File

@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 999999999L;
}
@Test
public void testBasicIris() {
Nd4j.getRandom().setSeed(12345);

View File

@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
private static class GradientCheckSimpleScenario {
private final ILossFunction lf;
private final Activation act;

View File

@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testGradientLRNSimple() {
Nd4j.getRandom().setSeed(12345);

View File

@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testLSTMBasicMultiLayer() {
//Basic test of GravesLSTM layer

View File

@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void lossFunctionGradientCheck() {
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),

View File

@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testGradientNoBiasDenseOutput() {

View File

@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testRnnLossLayer() {
Nd4j.getRandom().setSeed(12345L);

View File

@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testBidirectionalWrapper() {

View File

@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testMaskLayer() {
Nd4j.getRandom().setSeed(12345);

View File

@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
Nd4j.setDataType(DataType.DOUBLE);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testVaeAsMLP() {
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output

View File

@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
@TempDir
public File testDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testYoloOutputLayer() {
int depthIn = 2;

View File

@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest {
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
));
@Override
public long getTimeoutMilliseconds() {
return 9999999L;
}
@AfterAll
public static void after() {
ImmutableSet<ClassPath.ClassInfo> info;

View File

@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
@EqualsAndHashCode
@EqualsAndHashCode(callSuper = false)
public class CustomActivation extends BaseActivationFunction implements IActivation {
@Override
public INDArray getActivation(INDArray in, boolean training) {

View File

@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest {
public void doBefore() {
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testDnnForwardPass() {
int nOut = 10;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
@ -37,6 +38,7 @@ import java.util.Map;
@NoArgsConstructor
@Data
@EqualsAndHashCode(callSuper = false)
public class SameDiffDenseVertex extends SameDiffVertex {
private int nIn;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.multilayer;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest {
}
@Data
@EqualsAndHashCode(callSuper = false)
public static class CheckModelsListener extends BaseTrainingListener {
private Set<Class<?>> modelClasses = new HashSet<>();

View File

@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
@Slf4j
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 1200000L;
}
/**
* This test ensures that the memory amount assigned to the buffer is enough for any number of updates
* @throws Exception

View File

@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657")
@Timeout(120000L)
//@Timeout(120000L)
public class SmartFancyBlockingQueueTest extends BaseDL4JTest {
@Test

View File

@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class TestCheckpointListener extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@TempDir
public File tempDir;

View File

@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest {
@TempDir
public File tempDir;
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testSettingListenersUnsupervised() {
//Pretrain layers should get copies of the listeners, in addition to the

View File

@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void regressionTestMLP1() throws Exception {

View File

@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*;
@Slf4j
public class RegressionTest100a extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*;
public class RegressionTest100b3 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources;
public class RegressionTest100b4 extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Override
public DataType getDataType() {
return DataType.FLOAT;

View File

@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest {
return DataType.FLOAT;
}
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void testCustomLayer() throws Exception {

View File

@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestDistributionDeserializer extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
}
@Test
public void testDistributionDeserializer() throws Exception {
//Test current format:

View File

@ -25,4 +25,5 @@ dependencies {
implementation "org.slf4j:slf4j-api"
implementation "org.apache.commons:commons-lang3"
implementation "com.fasterxml.jackson.core:jackson-annotations"
}

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.common.config.DL4JClassLoading;
@ -46,6 +47,7 @@ import java.util.List;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
@ -39,6 +40,7 @@ import java.util.Map;
*/
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class KerasMasking extends KerasLayer {
private double maskingValue;

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.modelimport.keras.layers.core;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
@ -35,6 +36,7 @@ import java.util.Map;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = false)
public class KerasMerge extends KerasLayer {
private final String LAYER_FIELD_MODE = "mode";

Some files were not shown because too many files have changed in this diff.