parent
6856b154b1
commit
acdd9c0a8a
|
@ -66,6 +66,8 @@ dependencies {
|
||||||
implementation projects.cavisDnn.cavisDnnParallelwrapper
|
implementation projects.cavisDnn.cavisDnnParallelwrapper
|
||||||
|
|
||||||
implementation projects.cavisZoo.cavisZooModels
|
implementation projects.cavisZoo.cavisZooModels
|
||||||
|
|
||||||
|
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
|
||||||
}
|
}
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
|
|
@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.concurrent.atomic.AtomicInteger;
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
|
||||||
|
@ -154,6 +155,9 @@ public class IntegrationTestRunner {
|
||||||
evaluationClassesSeen = new HashMap<>();
|
evaluationClassesSeen = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void runTest(TestCase tc, Path testDir) throws Exception {
|
||||||
|
runTest(tc, testDir.toFile());
|
||||||
|
}
|
||||||
public static void runTest(TestCase tc, File testDir) throws Exception {
|
public static void runTest(TestCase tc, File testDir) throws Exception {
|
||||||
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
|
||||||
//This could alternatively be done via maven surefire configuration
|
//This could alternatively be done via maven surefire configuration
|
||||||
|
|
|
@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
|
||||||
|
|
||||||
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
|
||||||
public class IntegrationTestsDL4J extends BaseDL4JTest {
|
public class IntegrationTestsDL4J extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 300_000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public Path testDir;
|
||||||
|
|
||||||
@AfterAll
|
@AfterAll
|
||||||
public static void afterClass(){
|
public static void afterClass(){
|
||||||
|
|
|
@ -30,12 +30,6 @@ import java.io.File;
|
||||||
|
|
||||||
|
|
||||||
public class IntegrationTestsSameDiff extends BaseDL4JTest {
|
public class IntegrationTestsSameDiff extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
|
||||||
public long getTimeoutMilliseconds() {
|
|
||||||
return 300_000L;
|
|
||||||
}
|
|
||||||
|
|
||||||
@TempDir
|
@TempDir
|
||||||
public File testDir;
|
public File testDir;
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,7 @@ dependencies {
|
||||||
|
|
||||||
/*Logging*/
|
/*Logging*/
|
||||||
api 'org.slf4j:slf4j-api:1.7.30'
|
api 'org.slf4j:slf4j-api:1.7.30'
|
||||||
|
api 'org.slf4j:slf4j-simple:1.7.25'
|
||||||
|
|
||||||
api "org.apache.logging.log4j:log4j-core:2.17.0"
|
api "org.apache.logging.log4j:log4j-core:2.17.0"
|
||||||
api "ch.qos.logback:logback-classic:1.2.3"
|
api "ch.qos.logback:logback-classic:1.2.3"
|
||||||
|
|
|
@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
|
||||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
public class NativeImageLoader extends BaseImageLoader {
|
public class NativeImageLoader extends BaseImageLoader {
|
||||||
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
|
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
|
||||||
private byte[] buffer = null;
|
private byte[] buffer = null;
|
||||||
|
@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
|
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
|
||||||
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
|
"PNG", "TIF", "TIFF", "EXR", "WEBP"};
|
||||||
|
|
||||||
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
|
protected OpenCVFrameConverter.ToMat converter;
|
||||||
|
|
||||||
boolean direct = !Loader.getPlatform().startsWith("android");
|
boolean direct = !Loader.getPlatform().startsWith("android");
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads images with no scaling or conversion.
|
* Loads images with no scaling or conversion.
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader() {}
|
public NativeImageLoader() {
|
||||||
|
this.converter = new OpenCVFrameConverter.ToMat();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Instantiate an image with the given
|
* Instantiate an image with the given
|
||||||
|
@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
|
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader(long height, long width) {
|
public NativeImageLoader(long height, long width) {
|
||||||
|
this();
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
}
|
}
|
||||||
|
@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
* @param channels the number of channels for the image*
|
* @param channels the number of channels for the image*
|
||||||
*/
|
*/
|
||||||
public NativeImageLoader(long height, long width, long channels) {
|
public NativeImageLoader(long height, long width, long channels) {
|
||||||
this.height = height;
|
this(height, width);
|
||||||
this.width = width;
|
|
||||||
this.channels = channels;
|
this.channels = channels;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected NativeImageLoader(NativeImageLoader other) {
|
protected NativeImageLoader(NativeImageLoader other) {
|
||||||
this.height = other.height;
|
this(other.height, other.width, other.channels, other.multiPageMode);
|
||||||
this.width = other.width;
|
|
||||||
this.channels = other.channels;
|
|
||||||
this.centerCropIfNeeded = other.centerCropIfNeeded;
|
this.centerCropIfNeeded = other.centerCropIfNeeded;
|
||||||
this.imageTransform = other.imageTransform;
|
this.imageTransform = other.imageTransform;
|
||||||
this.multiPageMode = other.multiPageMode;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -297,7 +301,7 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
private Mat streamToMat(InputStream is) throws IOException {
|
private Mat streamToMat(InputStream is) throws IOException {
|
||||||
if(buffer == null){
|
if(buffer == null){
|
||||||
buffer = IOUtils.toByteArray(is);
|
buffer = IOUtils.toByteArray(is);
|
||||||
if(buffer.length <= 0){
|
if(buffer.length == 0){
|
||||||
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
|
||||||
}
|
}
|
||||||
bufferMat = new Mat(buffer);
|
bufferMat = new Mat(buffer);
|
||||||
|
@ -545,10 +549,15 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(InputStream is, INDArray view) throws IOException {
|
public void asMatrixView(InputStream is, INDArray view) throws IOException {
|
||||||
Mat mat = streamToMat(is);
|
throw new RuntimeException("Not implemented");
|
||||||
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
|
||||||
|
}
|
||||||
|
|
||||||
|
public void asMatrixView(String filename, INDArray view) throws IOException {
|
||||||
|
Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
|
||||||
|
//Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
|
||||||
if (image == null || image.empty()) {
|
if (image == null || image.empty()) {
|
||||||
PIX pix = pixReadMem(mat.data(), mat.cols());
|
PIX pix = pixReadMem(image.data(), image.cols());
|
||||||
if (pix == null) {
|
if (pix == null) {
|
||||||
throw new IOException("Could not decode image from input stream");
|
throw new IOException("Could not decode image from input stream");
|
||||||
}
|
}
|
||||||
|
@ -561,14 +570,8 @@ public class NativeImageLoader extends BaseImageLoader {
|
||||||
image.deallocate();
|
image.deallocate();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(String filename, INDArray view) throws IOException {
|
|
||||||
asMatrixView(new File(filename), view);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void asMatrixView(File f, INDArray view) throws IOException {
|
public void asMatrixView(File f, INDArray view) throws IOException {
|
||||||
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) {
|
asMatrixView(f.getAbsolutePath(), view);
|
||||||
asMatrixView(bis, view);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void asMatrixView(Mat image, INDArray view) throws IOException {
|
public void asMatrixView(Mat image, INDArray view) throws IOException {
|
||||||
|
|
|
@ -53,6 +53,10 @@ import java.io.*;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base class for the image record reader
|
||||||
|
*
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseImageRecordReader extends BaseRecordReader {
|
public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
protected boolean finishedInputStreamSplit;
|
protected boolean finishedInputStreamSplit;
|
||||||
|
@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
|
||||||
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
|
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
|
||||||
features.tensorAlongDimension(i, 1, 2, 3));
|
features.tensorAlongDimension(i, 1, 2, 3));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath());
|
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage());
|
||||||
|
e.printStackTrace();
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,13 +58,6 @@ public abstract class BaseDL4JTest {
|
||||||
return DEFAULT_THREADS;
|
return DEFAULT_THREADS;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Override this method to set the default timeout for methods in the test class
|
|
||||||
*/
|
|
||||||
public long getTimeoutMilliseconds(){
|
|
||||||
return 90_000;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Override this to set the profiling mode for the tests defined in the child class
|
* Override this to set the profiling mode for the tests defined in the child class
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.earlystopping;
|
package org.deeplearning4j.earlystopping;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.TestUtils;
|
import org.deeplearning4j.TestUtils;
|
||||||
|
@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public static class TestListener extends BaseTrainingListener {
|
public static class TestListener extends BaseTrainingListener {
|
||||||
private int countEpochStart = 0;
|
private int countEpochStart = 0;
|
||||||
private int countEpochEnd = 0;
|
private int countEpochEnd = 0;
|
||||||
|
|
|
@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class CustomActivation extends BaseActivationFunction implements IActivation {
|
public class CustomActivation extends BaseActivationFunction implements IActivation {
|
||||||
@Override
|
@Override
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
|
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
|
||||||
|
@ -37,6 +38,7 @@ import java.util.Map;
|
||||||
|
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@Data
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = false)
|
||||||
public class SameDiffDenseVertex extends SameDiffVertex {
|
public class SameDiffDenseVertex extends SameDiffVertex {
|
||||||
|
|
||||||
private int nIn;
|
private int nIn;
|
||||||
|
|
|
@ -27,6 +27,7 @@ dependencies {
|
||||||
implementation "com.fasterxml.jackson.core:jackson-core"
|
implementation "com.fasterxml.jackson.core:jackson-core"
|
||||||
implementation "com.fasterxml.jackson.core:jackson-databind"
|
implementation "com.fasterxml.jackson.core:jackson-databind"
|
||||||
implementation "org.slf4j:slf4j-api"
|
implementation "org.slf4j:slf4j-api"
|
||||||
|
implementation "org.slf4j:slf4j-simple"
|
||||||
implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerModel
|
implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerModel
|
||||||
implementation projects.cavisNd4j.cavisNd4jAeron
|
implementation projects.cavisNd4j.cavisNd4jAeron
|
||||||
implementation projects.cavisDnn.cavisDnnApi
|
implementation projects.cavisDnn.cavisDnnApi
|
||||||
|
|
|
@ -56,7 +56,11 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
|
||||||
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
|
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
|
||||||
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
.receiverIdleStrategy(new BusySpinIdleStrategy())
|
||||||
.senderIdleStrategy(new BusySpinIdleStrategy());
|
.senderIdleStrategy(new BusySpinIdleStrategy())
|
||||||
|
.driverTimeoutMs(1000*1000 *1000)
|
||||||
|
.clientLivenessTimeoutNs(1000*1000*1000)
|
||||||
|
.timerIntervalNs( 1000 * 1000);
|
||||||
|
|
||||||
|
|
||||||
mediaDriver = MediaDriver.launchEmbedded(ctx);
|
mediaDriver = MediaDriver.launchEmbedded(ctx);
|
||||||
aeron = Aeron.connect(getContext());
|
aeron = Aeron.connect(getContext());
|
||||||
|
|
|
@ -325,7 +325,7 @@ public class ParameterServerSubscriber implements AutoCloseable {
|
||||||
int tries=0;
|
int tries=0;
|
||||||
while (!subscriber.launched() && tries<12) {
|
while (!subscriber.launched() && tries<12) {
|
||||||
tries++;
|
tries++;
|
||||||
Thread.sleep(1000);
|
Thread.sleep(2000);
|
||||||
}
|
}
|
||||||
if(!subscriber.launched()) {
|
if(!subscriber.launched()) {
|
||||||
throw new Exception("Subscriber did not start in time.");
|
throw new Exception("Subscriber did not start in time.");
|
||||||
|
|
Loading…
Reference in New Issue