More test fixes
parent
4ad1987a07
commit
11ba7a59c1
|
@ -66,6 +66,8 @@ dependencies {
|
||||||
implementation projects.cavisDnn.cavisDnnParallelwrapper
|
implementation projects.cavisDnn.cavisDnnParallelwrapper
|
||||||
|
|
||||||
implementation projects.cavisZoo.cavisZooModels
|
implementation projects.cavisZoo.cavisZooModels
|
||||||
|
|
||||||
|
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
|
||||||
}
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
|
|
@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
@ -154,6 +155,9 @@ public class IntegrationTestRunner {
|
||||||
evaluationClassesSeen = new HashMap<>();
|
evaluationClassesSeen = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void runTest(TestCase tc, Path testDir) throws Exception {
|
||||||
|
runTest(tc, testDir.toFile());
|
||||||
|
}
|
||||||
public static void runTest(TestCase tc, File testDir) throws Exception {
|
public static void runTest(TestCase tc, File testDir) throws Exception {
|
||||||
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
||||||
//This could alternatively be done via maven surefire configuration
|
//This could alternatively be done via maven surefire configuration
|
||||||
|
|
|
@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
|
||||||
|
|
||||||
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
||||||
public class IntegrationTestsDL4J extends BaseDL4JTest {
|
public class IntegrationTestsDL4J extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 300_000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public Path testDir;
|
||||||
|
|
||||||
@AfterAll
|
@AfterAll
|
||||||
public static void afterClass(){
|
public static void afterClass(){
|
||||||
|
|
|
@ -30,12 +30,6 @@ import java.io.File;
|
||||||
|
|
||||||
|
|
||||||
public class IntegrationTestsSameDiff extends BaseDL4JTest {
|
public class IntegrationTestsSameDiff extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 300_000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public File testDir;
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,7 @@ dependencies {
|
||||||
|
|
||||||
/*Logging*/
|
/*Logging*/
|
||||||
api 'org.slf4j:slf4j-api:1.7.30'
|
api 'org.slf4j:slf4j-api:1.7.30'
|
||||||
|
api 'org.slf4j:slf4j-simple:1.7.25'
|
||||||
|
|
||||||
api "org.apache.logging.log4j:log4j-core:2.17.0"
|
api "org.apache.logging.log4j:log4j-core:2.17.0"
|
||||||
api "ch.qos.logback:logback-classic:1.2.3"
|
api "ch.qos.logback:logback-classic:1.2.3"
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.ColumnType;
|
import org.datavec.api.transform.ColumnType;
|
||||||
import org.datavec.api.transform.MathOp;
|
import org.datavec.api.transform.MathOp;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
|
@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
|
public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform {
|
||||||
|
|
||||||
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
|
public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName,
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.MathFunction;
|
import org.datavec.api.transform.MathFunction;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.datavec.api.transform.transform.BaseColumnTransform;
|
import org.datavec.api.transform.transform.BaseColumnTransform;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
|
public class NDArrayMathFunctionTransform extends BaseColumnTransform {
|
||||||
|
|
||||||
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD
|
//Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.ndarray;
|
package org.datavec.api.transform.ndarray;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.api.transform.MathOp;
|
import org.datavec.api.transform.MathOp;
|
||||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class NDArrayScalarOpTransform extends BaseColumnTransform {
|
public class NDArrayScalarOpTransform extends BaseColumnTransform {
|
||||||
|
|
||||||
private final MathOp mathOp;
|
private final MathOp mathOp;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.api.transform.transform.string;
|
package org.datavec.api.transform.transform.string;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
@ -31,6 +32,7 @@ import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
|
public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform {
|
||||||
/**
|
/**
|
||||||
* @param columnName The name of the column to convert
|
* @param columnName The name of the column to convert
|
||||||
|
|
|
@ -28,4 +28,5 @@ dependencies {
|
||||||
implementation "commons-io:commons-io"
|
implementation "commons-io:commons-io"
|
||||||
|
|
||||||
testImplementation projects.cavisNd4j.cavisNd4jCommonTests
|
testImplementation projects.cavisNd4j.cavisNd4jCommonTests
|
||||||
|
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
|
||||||
}
|
}
|
|
@ -151,7 +151,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
|
||||||
}
|
}
|
||||||
FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng);
|
FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng);
|
||||||
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples,
|
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples,
|
||||||
numLabels, 0, batchSize, null);
|
numLabels, 0, batchSize, (String) null);
|
||||||
inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
|
inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
public class NativeImageLoader extends BaseImageLoader {
|
public class NativeImageLoader extends BaseImageLoader {
|
||||||
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
|
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
|
||||||
private byte[] buffer = null;
|
private byte[] buffer = null;
|
||||||
|
@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
|
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
|
||||||
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
|
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
|
||||||
|
|
||||||
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
|
protected OpenCVFrameConverter.ToMat converter;
|
||||||
|
|
||||||
boolean direct = !Loader.getPlatform().startsWith("android");
|
boolean direct = !Loader.getPlatform().startsWith("android");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads images with no scaling or conversion.
|
* Loads images with no scaling or conversion.
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader() {}
|
public NativeImageLoader() {
|
||||||
|
this.converter = new OpenCVFrameConverter.ToMat();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instantiate an image with the given
|
* Instantiate an image with the given
|
||||||
|
@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader(long height, long width) {
|
public NativeImageLoader(long height, long width) {
|
||||||
|
this();
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
}
|
}
|
||||||
|
@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
* @param channels the number of channels for the image*
|
* @param channels the number of channels for the image*
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader(long height, long width, long channels) {
|
public NativeImageLoader(long height, long width, long channels) {
|
||||||
this.height = height;
|
this(height, width);
|
||||||
this.width = width;
|
|
||||||
this.channels = channels;
|
this.channels = channels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected NativeImageLoader(NativeImageLoader other) {
|
protected NativeImageLoader(NativeImageLoader other) {
|
||||||
this.height = other.height;
|
this(other.height, other.width, other.channels, other.multiPageMode);
|
||||||
this.width = other.width;
|
|
||||||
this.channels = other.channels;
|
|
||||||
this.centerCropIfNeeded = other.centerCropIfNeeded;
|
this.centerCropIfNeeded = other.centerCropIfNeeded;
|
||||||
this.imageTransform = other.imageTransform;
|
this.imageTransform = other.imageTransform;
|
||||||
this.multiPageMode = other.multiPageMode;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -252,8 +256,27 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
public INDArray asMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
|
||||||
return asMatrix(bis, nchw);
|
INDArray a;
|
||||||
|
if (this.multiPageMode != null) {
|
||||||
|
a = asMatrix(mat.data(), mat.cols());
|
||||||
|
}else{
|
||||||
|
// Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
|
if (mat == null || mat.empty()) {
|
||||||
|
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||||
|
if (pix == null) {
|
||||||
|
throw new IOException("Could not decode image from input stream");
|
||||||
|
}
|
||||||
|
mat = convert(pix);
|
||||||
|
pixDestroy(pix);
|
||||||
|
}
|
||||||
|
a = asMatrix(mat);
|
||||||
|
mat.deallocate();
|
||||||
|
}
|
||||||
|
if(nchw) {
|
||||||
|
return a;
|
||||||
|
} else {
|
||||||
|
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -264,6 +287,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
|
throw new RuntimeException("not implemented");
|
||||||
|
/*
|
||||||
Mat mat = streamToMat(inputStream);
|
Mat mat = streamToMat(inputStream);
|
||||||
INDArray a;
|
INDArray a;
|
||||||
if (this.multiPageMode != null) {
|
if (this.multiPageMode != null) {
|
||||||
|
@ -286,6 +311,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
} else {
|
} else {
|
||||||
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
return a.permute(0, 2, 3, 1); //NCHW to NHWC
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -297,7 +324,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
private Mat streamToMat(InputStream is) throws IOException {
|
private Mat streamToMat(InputStream is) throws IOException {
|
||||||
if(buffer == null){
|
if(buffer == null){
|
||||||
buffer = IOUtils.toByteArray(is);
|
buffer = IOUtils.toByteArray(is);
|
||||||
if(buffer.length <= 0){
|
if(buffer.length == 0){
|
||||||
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
||||||
}
|
}
|
||||||
bufferMat = new Mat(buffer);
|
bufferMat = new Mat(buffer);
|
||||||
|
@ -354,9 +381,13 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
public Image asImageMatrix(File f, boolean nchw) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
Mat image = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
return asImageMatrix(bis, nchw);
|
INDArray a = asMatrix(image);
|
||||||
}
|
if(!nchw)
|
||||||
|
a = a.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
Image i = new Image(a, image.channels(), image.rows(), image.cols());
|
||||||
|
image.deallocate();
|
||||||
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -366,7 +397,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException {
|
||||||
Mat mat = streamToMat(inputStream);
|
throw new RuntimeException("Deprecated. Not implemented.");
|
||||||
|
/*Mat mat = streamToMat(inputStream);
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
PIX pix = pixReadMem(mat.data(), mat.cols());
|
||||||
|
@ -383,6 +415,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
return i;
|
return i;
|
||||||
|
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -545,10 +579,15 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(InputStream is, INDArray view) throws IOException {
|
public void asMatrixView(InputStream is, INDArray view) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
throw new RuntimeException("Not implemented");
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void asMatrixView(String filename, INDArray view) throws IOException {
|
||||||
|
Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
|
||||||
|
//Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
PIX pix = pixReadMem(image.data(), image.cols());
|
||||||
if (pix == null) {
|
if (pix == null) {
|
||||||
throw new IOException("Could not decode image from input stream");
|
throw new IOException("Could not decode image from input stream");
|
||||||
}
|
}
|
||||||
|
@ -561,14 +600,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(String filename, INDArray view) throws IOException {
|
|
||||||
asMatrixView(new File(filename), view);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void asMatrixView(File f, INDArray view) throws IOException {
|
public void asMatrixView(File f, INDArray view) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
asMatrixView(f.getAbsolutePath(), view);
|
||||||
asMatrixView(bis, view);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(Mat image, INDArray view) throws IOException {
|
public void asMatrixView(Mat image, INDArray view) throws IOException {
|
||||||
|
|
|
@ -53,6 +53,10 @@ import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for the image record reader
|
||||||
|
*
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseImageRecordReader extends BaseRecordReader {
|
public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
protected boolean finishedInputStreamSplit;
|
protected boolean finishedInputStreamSplit;
|
||||||
|
@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
|
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
|
||||||
features.tensorAlongDimension(i, 1, 2, 3));
|
features.tensorAlongDimension(i, 1, 2, 3));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath());
|
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage());
|
||||||
|
e.printStackTrace();
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
import lombok.experimental.Accessors;
|
import lombok.experimental.Accessors;
|
||||||
|
@ -38,6 +39,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
@JsonIgnoreProperties({"borderValue"})
|
@JsonIgnoreProperties({"borderValue"})
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class BoxImageTransform extends BaseImageTransform<Mat> {
|
public class BoxImageTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
private int width;
|
private int width;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||||
|
@ -32,6 +33,7 @@ import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
|
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class FlipImageTransform extends BaseImageTransform<Mat> {
|
public class FlipImageTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
|
public class LargestBlobCropTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected org.nd4j.linalg.api.rng.Random rng;
|
protected org.nd4j.linalg.api.rng.Random rng;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
@ -28,6 +29,7 @@ import java.util.Random;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class MultiImageTransform extends BaseImageTransform<Mat> {
|
public class MultiImageTransform extends BaseImageTransform<Mat> {
|
||||||
private PipelineImageTransform transform;
|
private PipelineImageTransform transform;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
|
@ -32,6 +33,7 @@ import java.util.*;
|
||||||
import org.bytedeco.opencv.opencv_core.*;
|
import org.bytedeco.opencv.opencv_core.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class PipelineImageTransform extends BaseImageTransform<Mat> {
|
public class PipelineImageTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected List<Pair<ImageTransform, Double>> imageTransforms;
|
protected List<Pair<ImageTransform, Double>> imageTransforms;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.datavec.image.transform;
|
package org.datavec.image.transform;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||||
import org.datavec.image.data.ImageWritable;
|
import org.datavec.image.data.ImageWritable;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*;
|
||||||
@JsonIgnoreProperties({"rng", "converter"})
|
@JsonIgnoreProperties({"rng", "converter"})
|
||||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class RandomCropTransform extends BaseImageTransform<Mat> {
|
public class RandomCropTransform extends BaseImageTransform<Mat> {
|
||||||
|
|
||||||
protected int outputHeight;
|
protected int outputHeight;
|
||||||
|
|
|
@ -612,28 +612,6 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
NativeImageLoader il = new NativeImageLoader(32, 32, 3);
|
||||||
|
|
||||||
//asMatrix(File, boolean)
|
|
||||||
INDArray a_nchw = il.asMatrix(f);
|
|
||||||
INDArray a_nchw2 = il.asMatrix(f, true);
|
|
||||||
INDArray a_nhwc = il.asMatrix(f, false);
|
|
||||||
|
|
||||||
assertEquals(a_nchw, a_nchw2);
|
|
||||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
|
||||||
|
|
||||||
|
|
||||||
//asMatrix(InputStream, boolean)
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
a_nchw = il.asMatrix(is);
|
|
||||||
}
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
a_nchw2 = il.asMatrix(is, true);
|
|
||||||
}
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
a_nhwc = il.asMatrix(is, false);
|
|
||||||
}
|
|
||||||
assertEquals(a_nchw, a_nchw2);
|
|
||||||
assertEquals(a_nchw, a_nhwc.permute(0,3,1,2));
|
|
||||||
|
|
||||||
|
|
||||||
//asImageMatrix(File, boolean)
|
//asImageMatrix(File, boolean)
|
||||||
Image i_nchw = il.asImageMatrix(f);
|
Image i_nchw = il.asImageMatrix(f);
|
||||||
|
@ -642,20 +620,6 @@ public class TestNativeImageLoader {
|
||||||
|
|
||||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
||||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
||||||
|
|
||||||
|
|
||||||
//asImageMatrix(InputStream, boolean)
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
i_nchw = il.asImageMatrix(is);
|
|
||||||
}
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
i_nchw2 = il.asImageMatrix(is, true);
|
|
||||||
}
|
|
||||||
try(InputStream is = new BufferedInputStream(new FileInputStream(f))){
|
|
||||||
i_nhwc = il.asImageMatrix(is, false);
|
|
||||||
}
|
|
||||||
assertEquals(i_nchw.getImage(), i_nchw2.getImage());
|
|
||||||
assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,11 +20,8 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff.functions;
|
package org.nd4j.autodiff.functions;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.*;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
|
|
@ -46,7 +46,9 @@ public class SDVariable implements Serializable {
|
||||||
protected SameDiff sameDiff;
|
protected SameDiff sameDiff;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
|
@Setter
|
||||||
protected String varName;
|
protected String varName;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
protected VariableType variableType;
|
protected VariableType variableType;
|
||||||
|
@ -83,18 +85,6 @@ public class SDVariable implements Serializable {
|
||||||
return varName;
|
return varName;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setVarName(String varName) {
|
|
||||||
this.varName = varName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use {@link #name()}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public String getVarName(){
|
|
||||||
return name();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns true if this variable is a place holder
|
* Returns true if this variable is a place holder
|
||||||
* @return
|
* @return
|
||||||
|
|
|
@ -39,5 +39,6 @@ public class Variable {
|
||||||
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
|
||||||
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
|
||||||
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
protected SDVariable gradient; //Variable corresponding to the gradient of this variable
|
||||||
|
@Builder.Default
|
||||||
protected int variableIndex = -1;
|
protected int variableIndex = -1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ public class SameDiffUtils {
|
||||||
|
|
||||||
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
|
public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> externalGradients, SDVariable... inputs) {
|
||||||
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
|
Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" +
|
||||||
" be specified when using external errors: got %s", inputs);
|
" be specified when using external errors: got %s", (Object) inputs);
|
||||||
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
|
ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients);
|
||||||
fn.outputVariable();
|
fn.outputVariable();
|
||||||
return fn;
|
return fn;
|
||||||
|
|
|
@ -49,7 +49,7 @@ import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
|
public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration> {
|
||||||
|
|
||||||
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
|
public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
|
||||||
|
|
|
@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Histogram extends BaseHistogram {
|
public class Histogram extends BaseHistogram {
|
||||||
private final String title;
|
private final String title;
|
||||||
private final double lower;
|
private final double lower;
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
package org.nd4j.linalg.api.memory.deallocation;
|
package org.nd4j.linalg.api.memory.deallocation;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.memory.Deallocatable;
|
import org.nd4j.linalg.api.memory.Deallocatable;
|
||||||
import org.nd4j.linalg.api.memory.Deallocator;
|
import org.nd4j.linalg.api.memory.Deallocator;
|
||||||
|
|
||||||
|
@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue;
|
||||||
import java.lang.ref.WeakReference;
|
import java.lang.ref.WeakReference;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeallocatableReference extends WeakReference<Deallocatable> {
|
public class DeallocatableReference extends WeakReference<Deallocatable> {
|
||||||
private String id;
|
private String id;
|
||||||
private Deallocator deallocator;
|
private Deallocator deallocator;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops;
|
package org.nd4j.linalg.api.ops;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
|
public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
|
|
||||||
|
|
|
@ -44,6 +44,9 @@ import java.lang.reflect.Array;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@Builder
|
||||||
|
@AllArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
||||||
|
|
||||||
private String opName;
|
private String opName;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -35,6 +36,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Flatten extends DynamicCustomOp {
|
public class Flatten extends DynamicCustomOp {
|
||||||
private int order;
|
private int order;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,8 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public abstract class BaseCompatOp extends DynamicCustomOp {
|
public abstract class BaseCompatOp extends DynamicCustomOp {
|
||||||
protected String frameName;
|
protected String frameName;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Enter extends BaseCompatOp {
|
public class Enter extends BaseCompatOp {
|
||||||
|
|
||||||
protected boolean isConstant;
|
protected boolean isConstant;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
package org.nd4j.linalg.api.ops.impl.controlflow.compat;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class While extends BaseCompatOp {
|
public class While extends BaseCompatOp {
|
||||||
|
|
||||||
protected boolean isConstant;
|
protected boolean isConstant;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -36,6 +37,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class FirstIndex extends BaseIndexAccumulation {
|
public class FirstIndex extends BaseIndexAccumulation {
|
||||||
protected Condition condition;
|
protected Condition condition;
|
||||||
protected double compare;
|
protected double compare;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LastIndex extends BaseIndexAccumulation {
|
public class LastIndex extends BaseIndexAccumulation {
|
||||||
protected Condition condition;
|
protected Condition condition;
|
||||||
protected double compare;
|
protected double compare;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgAmax extends DynamicCustomOp {
|
public class ArgAmax extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -38,6 +39,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgAmin extends DynamicCustomOp {
|
public class ArgAmin extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgMax extends DynamicCustomOp {
|
public class ArgMax extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class ArgMin extends DynamicCustomOp {
|
public class ArgMin extends DynamicCustomOp {
|
||||||
protected boolean keepDims = false;
|
protected boolean keepDims = false;
|
||||||
private int[] dimensions;
|
private int[] dimensions;
|
||||||
|
|
|
@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config;
|
||||||
|
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
import lombok.*;
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.NonNull;
|
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv1DConfig extends BaseConvolutionConfig {
|
public class Conv1DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCW = "NCW";
|
public static final String NCW = "NCW";
|
||||||
public static final String NWC = "NWC";
|
public static final String NWC = "NWC";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.enums.WeightsFormat;
|
import org.nd4j.enums.WeightsFormat;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv2DConfig extends BaseConvolutionConfig {
|
public class Conv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
public static final String NHWC = "NHWC";
|
public static final String NHWC = "NHWC";
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Conv3DConfig extends BaseConvolutionConfig {
|
public class Conv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NDHWC = "NDHWC";
|
public static final String NDHWC = "NDHWC";
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeConv2DConfig extends BaseConvolutionConfig {
|
public class DeConv2DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCHW = "NCHW";
|
public static final String NCHW = "NCHW";
|
||||||
public static final String NHWC = "NHWC";
|
public static final String NHWC = "NHWC";
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class DeConv3DConfig extends BaseConvolutionConfig {
|
public class DeConv3DConfig extends BaseConvolutionConfig {
|
||||||
public static final String NCDHW = "NCDHW";
|
public static final String NCDHW = "NCDHW";
|
||||||
public static final String NDHWC = "NDHWC";
|
public static final String NDHWC = "NDHWC";
|
||||||
|
|
|
@ -24,12 +24,14 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.util.ConvConfigUtil;
|
import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
public class LocalResponseNormalizationConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
private double alpha, beta, bias;
|
private double alpha, beta, bias;
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor;
|
||||||
|
@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Pooling2DConfig extends BaseConvolutionConfig {
|
public class Pooling2DConfig extends BaseConvolutionConfig {
|
||||||
|
|
||||||
@Builder.Default private long kH = -1, kW = -1;
|
@Builder.Default private long kH = -1, kW = -1;
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.util.LinkedHashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType;
|
||||||
|
@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil;
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Pooling3DConfig extends BaseConvolutionConfig {
|
public class Pooling3DConfig extends BaseConvolutionConfig {
|
||||||
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
@Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel
|
||||||
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
@Builder.Default private long sD = 1, sW = 1, sH = 1; // strides
|
||||||
|
|
|
@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef;
|
||||||
import java.lang.reflect.Field;
|
import java.lang.reflect.Field;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class Mmul extends DynamicCustomOp {
|
public class Mmul extends DynamicCustomOp {
|
||||||
|
|
||||||
protected MMulTranspose mt;
|
protected MMulTranspose mt;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class MmulBp extends DynamicCustomOp {
|
public class MmulBp extends DynamicCustomOp {
|
||||||
|
|
||||||
protected MMulTranspose mt;
|
protected MMulTranspose mt;
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class BatchMmul extends DynamicCustomOp {
|
public class BatchMmul extends DynamicCustomOp {
|
||||||
|
|
||||||
protected int transposeA;
|
protected int transposeA;
|
||||||
|
|
|
@ -39,10 +39,15 @@ import java.util.Map;
|
||||||
public class BalanceMinibatches {
|
public class BalanceMinibatches {
|
||||||
private DataSetIterator dataSetIterator;
|
private DataSetIterator dataSetIterator;
|
||||||
private int numLabels;
|
private int numLabels;
|
||||||
|
@Builder.Default
|
||||||
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
private Map<Integer, List<File>> paths = Maps.newHashMap();
|
||||||
|
@Builder.Default
|
||||||
private int miniBatchSize = -1;
|
private int miniBatchSize = -1;
|
||||||
|
@Builder.Default
|
||||||
private File rootDir = new File("minibatches");
|
private File rootDir = new File("minibatches");
|
||||||
|
@Builder.Default
|
||||||
private File rootSaveDir = new File("minibatchessave");
|
private File rootSaveDir = new File("minibatchessave");
|
||||||
|
@Builder.Default
|
||||||
private List<File> labelRootDirs = new ArrayList<>();
|
private List<File> labelRootDirs = new ArrayList<>();
|
||||||
private DataNormalization dataNormalization;
|
private DataNormalization dataNormalization;
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,8 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.learning.config;
|
package org.nd4j.linalg.learning.config;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.learning.AdaMaxUpdater;
|
import org.nd4j.linalg.learning.AdaMaxUpdater;
|
||||||
import org.nd4j.linalg.learning.GradientUpdater;
|
import org.nd4j.linalg.learning.GradientUpdater;
|
||||||
|
@ -44,7 +45,8 @@ public class AdaMax implements IUpdater {
|
||||||
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
|
public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9;
|
||||||
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
|
public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999;
|
||||||
|
|
||||||
@lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
@lombok.Builder.Default
|
||||||
|
private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate
|
||||||
private ISchedule learningRateSchedule;
|
private ISchedule learningRateSchedule;
|
||||||
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
@lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate
|
||||||
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate
|
@lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate
|
||||||
|
|
|
@ -335,20 +335,6 @@ public class OpProfiler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Dev-time method.
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
protected StackAggregator getMixedOrderAggregator() {
|
|
||||||
// FIXME: remove this method, or make it protected
|
|
||||||
return mixedOrderAggregator;
|
|
||||||
}
|
|
||||||
|
|
||||||
public StackAggregator getScalarAggregator() {
|
|
||||||
return scalarAggregator;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected void updatePairs(String opName, String opClass) {
|
protected void updatePairs(String opName, String opClass) {
|
||||||
// now we save pairs of ops/classes
|
// now we save pairs of ops/classes
|
||||||
String cOpNameKey = prevOpName + " -> " + opName;
|
String cOpNameKey = prevOpName + " -> " + opName;
|
||||||
|
|
|
@ -58,13 +58,6 @@ public abstract class BaseDL4JTest {
|
||||||
return DEFAULT_THREADS;
|
return DEFAULT_THREADS;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Override this method to set the default timeout for methods in the test class
|
|
||||||
*/
|
|
||||||
public long getTimeoutMilliseconds(){
|
|
||||||
return 90_000;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Override this to set the profiling mode for the tests defined in the child class
|
* Override this to set the profiling mode for the tests defined in the child class
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -16,4 +16,6 @@ dependencies {
|
||||||
implementation 'org.apache.commons:commons-math3'
|
implementation 'org.apache.commons:commons-math3'
|
||||||
implementation 'org.apache.commons:commons-lang3'
|
implementation 'org.apache.commons:commons-lang3'
|
||||||
implementation 'org.apache.commons:commons-compress'
|
implementation 'org.apache.commons:commons-compress'
|
||||||
}
|
|
||||||
|
testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'
|
||||||
|
}
|
||||||
|
|
|
@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
public class TestDataSets extends BaseDL4JTest {
|
public class TestDataSets extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTinyImageNetExists() throws Exception {
|
public void testTinyImageNetExists() throws Exception {
|
||||||
//Simple sanity check on extracting
|
//Simple sanity check on extracting
|
||||||
|
|
|
@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
*/
|
*/
|
||||||
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time.
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSvhnDataFetcher() throws Exception {
|
public void testSvhnDataFetcher() throws Exception {
|
||||||
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
|
assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access
|
||||||
|
|
|
@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class DataSetIteratorTest extends BaseDL4JTest {
|
public class DataSetIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBatchSizeOfOneIris() throws Exception {
|
public void testBatchSizeOfOneIris() throws Exception {
|
||||||
//Test for (a) iterators returning correct number of examples, and
|
//Test for (a) iterators returning correct number of examples, and
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class TestListener extends BaseTrainingListener {
|
public static class TestListener extends BaseTrainingListener {
|
||||||
private int countEpochStart = 0;
|
private int countEpochStart = 0;
|
||||||
private int countEpochEnd = 0;
|
private int countEpochEnd = 0;
|
||||||
|
|
|
@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
////@Ignore
|
////@Ignore
|
||||||
public class AttentionLayerTest extends BaseDL4JTest {
|
public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSelfAttentionLayer() {
|
public void testSelfAttentionLayer() {
|
||||||
|
|
|
@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradient2dSimple() {
|
public void testGradient2dSimple() {
|
||||||
DataNormalization scaler = new NormalizerMinMaxScaler();
|
DataNormalization scaler = new NormalizerMinMaxScaler();
|
||||||
|
|
|
@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCnn1DWithLocallyConnected1D() {
|
public void testCnn1DWithLocallyConnected1D() {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCnn3DPlain() {
|
public void testCnn3DPlain() {
|
||||||
Nd4j.getRandom().setSeed(1337);
|
Nd4j.getRandom().setSeed(1337);
|
||||||
|
|
|
@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
return CNN2DFormat.values();
|
return CNN2DFormat.values();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 999990000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientCNNMLN() {
|
public void testGradientCNNMLN() {
|
||||||
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format...
|
||||||
|
|
|
@ -49,11 +49,6 @@ import java.util.Random;
|
||||||
////@Ignore
|
////@Ignore
|
||||||
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCapsNet() {
|
public void testCapsNet() {
|
||||||
|
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDropoutGradient() {
|
public void testDropoutGradient() {
|
||||||
int minibatch = 3;
|
int minibatch = 3;
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
private static final double DEFAULT_MAX_REL_ERROR = 1e-3;
|
||||||
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRNNGlobalPoolingBasicMultiLayer() {
|
public void testRNNGlobalPoolingBasicMultiLayer() {
|
||||||
//Basic test of global pooling w/ LSTM
|
//Basic test of global pooling w/ LSTM
|
||||||
|
|
|
@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMinibatchApplication() {
|
public void testMinibatchApplication() {
|
||||||
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);
|
IrisDataSetIterator iter = new IrisDataSetIterator(30, 150);
|
||||||
|
|
|
@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 999999999L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBasicIris() {
|
public void testBasicIris() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class GradientCheckSimpleScenario {
|
private static class GradientCheckSimpleScenario {
|
||||||
private final ILossFunction lf;
|
private final ILossFunction lf;
|
||||||
private final Activation act;
|
private final Activation act;
|
||||||
|
|
|
@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientLRNSimple() {
|
public void testGradientLRNSimple() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testLSTMBasicMultiLayer() {
|
public void testLSTMBasicMultiLayer() {
|
||||||
//Basic test of GravesLSTM layer
|
//Basic test of GravesLSTM layer
|
||||||
|
|
|
@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
|
private static final double DEFAULT_MAX_REL_ERROR = 1e-5;
|
||||||
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
private static final double DEFAULT_MIN_ABS_ERROR = 1e-8;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void lossFunctionGradientCheck() {
|
public void lossFunctionGradientCheck() {
|
||||||
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
|
ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(),
|
||||||
|
|
|
@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGradientNoBiasDenseOutput() {
|
public void testGradientNoBiasDenseOutput() {
|
||||||
|
|
||||||
|
|
|
@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRnnLossLayer() {
|
public void testRnnLossLayer() {
|
||||||
Nd4j.getRandom().setSeed(12345L);
|
Nd4j.getRandom().setSeed(12345L);
|
||||||
|
|
|
@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
|
////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
|
||||||
public void testBidirectionalWrapper() {
|
public void testBidirectionalWrapper() {
|
||||||
|
|
|
@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMaskLayer() {
|
public void testMaskLayer() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
|
@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest {
|
||||||
Nd4j.setDataType(DataType.DOUBLE);
|
Nd4j.setDataType(DataType.DOUBLE);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVaeAsMLP() {
|
public void testVaeAsMLP() {
|
||||||
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output
|
//Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output
|
||||||
|
|
|
@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public File testDir;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testYoloOutputLayer() {
|
public void testYoloOutputLayer() {
|
||||||
int depthIn = 2;
|
int depthIn = 2;
|
||||||
|
|
|
@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated
|
||||||
));
|
));
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 9999999L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@AfterAll
|
@AfterAll
|
||||||
public static void after() {
|
public static void after() {
|
||||||
ImmutableSet<ClassPath.ClassInfo> info;
|
ImmutableSet<ClassPath.ClassInfo> info;
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class CustomActivation extends BaseActivationFunction implements IActivation {
|
public class CustomActivation extends BaseActivationFunction implements IActivation {
|
||||||
@Override
|
@Override
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
|
|
|
@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest {
|
||||||
public void doBefore() {
|
public void doBefore() {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDnnForwardPass() {
|
public void testDnnForwardPass() {
|
||||||
int nOut = 10;
|
int nOut = 10;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
|
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class SameDiffDenseVertex extends SameDiffVertex {
|
public class SameDiffDenseVertex extends SameDiffVertex {
|
||||||
|
|
||||||
private int nIn;
|
private int nIn;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.multilayer;
|
package org.deeplearning4j.nn.multilayer;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class CheckModelsListener extends BaseTrainingListener {
|
public static class CheckModelsListener extends BaseTrainingListener {
|
||||||
|
|
||||||
private Set<Class<?>> modelClasses = new HashSet<>();
|
private Set<Class<?>> modelClasses = new HashSet<>();
|
||||||
|
|
|
@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 1200000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This test ensures, that memory amount assigned to buffer is enough for any number of updates
|
* This test ensures, that memory amount assigned to buffer is enough for any number of updates
|
||||||
* @throws Exception
|
* @throws Exception
|
||||||
|
|
|
@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657")
|
////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657")
|
||||||
@Timeout(120000L)
|
//@Timeout(120000L)
|
||||||
public class SmartFancyBlockingQueueTest extends BaseDL4JTest {
|
public class SmartFancyBlockingQueueTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class TestCheckpointListener extends BaseDL4JTest {
|
public class TestCheckpointListener extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File tempDir;
|
public File tempDir;
|
||||||
|
|
||||||
|
|
|
@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest {
|
||||||
@TempDir
|
@TempDir
|
||||||
public File tempDir;
|
public File tempDir;
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 90000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSettingListenersUnsupervised() {
|
public void testSettingListenersUnsupervised() {
|
||||||
//Pretrain layers should get copies of the listeners, in addition to the
|
//Pretrain layers should get copies of the listeners, in addition to the
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void regressionTestMLP1() throws Exception {
|
public void regressionTestMLP1() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RegressionTest100a extends BaseDL4JTest {
|
public class RegressionTest100a extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
|
||||||
public class RegressionTest100b3 extends BaseDL4JTest {
|
public class RegressionTest100b3 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources;
|
||||||
|
|
||||||
public class RegressionTest100b4 extends BaseDL4JTest {
|
public class RegressionTest100b4 extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType() {
|
public DataType getDataType() {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest {
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomLayer() throws Exception {
|
public void testCustomLayer() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
|
|
||||||
public class TestDistributionDeserializer extends BaseDL4JTest {
|
public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializer() throws Exception {
|
public void testDistributionDeserializer() throws Exception {
|
||||||
//Test current format:
|
//Test current format:
|
||||||
|
|
|
@ -25,4 +25,5 @@ dependencies {
|
||||||
|
|
||||||
implementation "org.slf4j:slf4j-api"
|
implementation "org.slf4j:slf4j-api"
|
||||||
implementation "org.apache.commons:commons-lang3"
|
implementation "org.apache.commons:commons-lang3"
|
||||||
|
implementation "com.fasterxml.jackson.core:jackson-annotations"
|
||||||
}
|
}
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers;
|
package org.deeplearning4j.nn.modelimport.keras.layers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.ArrayUtils;
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.deeplearning4j.common.config.DL4JClassLoading;
|
import org.deeplearning4j.common.config.DL4JClassLoading;
|
||||||
|
@ -46,6 +47,7 @@ import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
public class TFOpLayerImpl extends AbstractLayer<TFOpLayer> {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;
|
||||||
|
@ -39,6 +40,7 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class KerasMasking extends KerasLayer {
|
public class KerasMasking extends KerasLayer {
|
||||||
|
|
||||||
private double maskingValue;
|
private double maskingValue;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
package org.deeplearning4j.nn.modelimport.keras.layers.core;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||||
|
@ -35,6 +36,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class KerasMerge extends KerasLayer {
|
public class KerasMerge extends KerasLayer {
|
||||||
|
|
||||||
private final String LAYER_FIELD_MODE = "mode";
|
private final String LAYER_FIELD_MODE = "mode";
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue