More test fixes

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2022-10-07 10:49:08 +02:00
parent 6856b154b1
commit acdd9c0a8a
14 changed files with 49 additions and 42 deletions

View File

@ -66,6 +66,8 @@ dependencies {
implementation projects.cavisDnn.cavisDnnParallelwrapper implementation projects.cavisDnn.cavisDnnParallelwrapper
implementation projects.cavisZoo.cavisZooModels implementation projects.cavisZoo.cavisZooModels
testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
} }
test { test {

View File

@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.io.*; import java.io.*;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*; import java.util.*;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
@ -154,6 +155,9 @@ public class IntegrationTestRunner {
evaluationClassesSeen = new HashMap<>(); evaluationClassesSeen = new HashMap<>();
} }
public static void runTest(TestCase tc, Path testDir) throws Exception {
runTest(tc, testDir.toFile());
}
public static void runTest(TestCase tc, File testDir) throws Exception { public static void runTest(TestCase tc, File testDir) throws Exception {
BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled. BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled.
//This could alternatively be done via maven surefire configuration //This could alternatively be done via maven surefire configuration

View File

@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import java.io.File; import java.io.File;
import java.nio.file.Path;
////@Ignore("AB - 2019/05/27 - Integration tests need to be updated") ////@Ignore("AB - 2019/05/27 - Integration tests need to be updated")
public class IntegrationTestsDL4J extends BaseDL4JTest { public class IntegrationTestsDL4J extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir @TempDir
public File testDir; public Path testDir;
@AfterAll @AfterAll
public static void afterClass(){ public static void afterClass(){

View File

@ -30,12 +30,6 @@ import java.io.File;
public class IntegrationTestsSameDiff extends BaseDL4JTest { public class IntegrationTestsSameDiff extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 300_000L;
}
@TempDir @TempDir
public File testDir; public File testDir;

View File

@ -65,6 +65,7 @@ dependencies {
/*Logging*/ /*Logging*/
api 'org.slf4j:slf4j-api:1.7.30' api 'org.slf4j:slf4j-api:1.7.30'
api 'org.slf4j:slf4j-simple:1.7.25'
api "org.apache.logging.log4j:log4j-core:2.17.0" api "org.apache.logging.log4j:log4j-core:2.17.0"
api "ch.qos.logback:logback-classic:1.2.3" api "ch.qos.logback:logback-classic:1.2.3"

View File

@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgcodecs.*; import static org.bytedeco.opencv.global.opencv_imgcodecs.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
/**
* Uses JavaCV to load images. Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp
*
* @author saudet
*/
public class NativeImageLoader extends BaseImageLoader { public class NativeImageLoader extends BaseImageLoader {
private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024; private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024;
private byte[] buffer = null; private byte[] buffer = null;
@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader {
"png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM",
"PNG", "TIF", "TIFF", "EXR", "WEBP"}; "PNG", "TIF", "TIFF", "EXR", "WEBP"};
protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); protected OpenCVFrameConverter.ToMat converter;
boolean direct = !Loader.getPlatform().startsWith("android"); boolean direct = !Loader.getPlatform().startsWith("android");
/** /**
* Loads images with no scaling or conversion. * Loads images with no scaling or conversion.
*/ */
public NativeImageLoader() {} public NativeImageLoader() {
this.converter = new OpenCVFrameConverter.ToMat();
}
/** /**
* Instantiate an image with the given * Instantiate an image with the given
@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader {
*/ */
public NativeImageLoader(long height, long width) { public NativeImageLoader(long height, long width) {
this();
this.height = height; this.height = height;
this.width = width; this.width = width;
} }
@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader {
* @param channels the number of channels for the image* * @param channels the number of channels for the image*
*/ */
public NativeImageLoader(long height, long width, long channels) { public NativeImageLoader(long height, long width, long channels) {
this.height = height; this(height, width);
this.width = width;
this.channels = channels; this.channels = channels;
} }
@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader {
} }
protected NativeImageLoader(NativeImageLoader other) { protected NativeImageLoader(NativeImageLoader other) {
this.height = other.height; this(other.height, other.width, other.channels, other.multiPageMode);
this.width = other.width;
this.channels = other.channels;
this.centerCropIfNeeded = other.centerCropIfNeeded; this.centerCropIfNeeded = other.centerCropIfNeeded;
this.imageTransform = other.imageTransform; this.imageTransform = other.imageTransform;
this.multiPageMode = other.multiPageMode;
} }
@Override @Override
@ -297,7 +301,7 @@ public class NativeImageLoader extends BaseImageLoader {
private Mat streamToMat(InputStream is) throws IOException { private Mat streamToMat(InputStream is) throws IOException {
if(buffer == null){ if(buffer == null){
buffer = IOUtils.toByteArray(is); buffer = IOUtils.toByteArray(is);
if(buffer.length <= 0){ if(buffer.length == 0){
throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); throw new IOException("Could not decode image from input stream: input stream was empty (no data)");
} }
bufferMat = new Mat(buffer); bufferMat = new Mat(buffer);
@ -545,10 +549,15 @@ public class NativeImageLoader extends BaseImageLoader {
} }
public void asMatrixView(InputStream is, INDArray view) throws IOException { public void asMatrixView(InputStream is, INDArray view) throws IOException {
Mat mat = streamToMat(is); throw new RuntimeException("Not implemented");
Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
}
public void asMatrixView(String filename, INDArray view) throws IOException {
Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR );
//Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR);
if (image == null || image.empty()) { if (image == null || image.empty()) {
PIX pix = pixReadMem(mat.data(), mat.cols()); PIX pix = pixReadMem(image.data(), image.cols());
if (pix == null) { if (pix == null) {
throw new IOException("Could not decode image from input stream"); throw new IOException("Could not decode image from input stream");
} }
@ -561,14 +570,8 @@ public class NativeImageLoader extends BaseImageLoader {
image.deallocate(); image.deallocate();
} }
public void asMatrixView(String filename, INDArray view) throws IOException {
asMatrixView(new File(filename), view);
}
public void asMatrixView(File f, INDArray view) throws IOException { public void asMatrixView(File f, INDArray view) throws IOException {
try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { asMatrixView(f.getAbsolutePath(), view);
asMatrixView(bis, view);
}
} }
public void asMatrixView(Mat image, INDArray view) throws IOException { public void asMatrixView(Mat image, INDArray view) throws IOException {

View File

@ -53,6 +53,10 @@ import java.io.*;
import java.net.URI; import java.net.URI;
import java.util.*; import java.util.*;
/**
* Base class for the image record reader
*
*/
@Slf4j @Slf4j
public abstract class BaseImageRecordReader extends BaseRecordReader { public abstract class BaseImageRecordReader extends BaseRecordReader {
protected boolean finishedInputStreamSplit; protected boolean finishedInputStreamSplit;
@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader {
((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i), ((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
features.tensorAlongDimension(i, 1, 2, 3)); features.tensorAlongDimension(i, 1, 2, 3));
} catch (Exception e) { } catch (Exception e) {
System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath()); System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage());
e.printStackTrace();
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }

View File

@ -58,13 +58,6 @@ public abstract class BaseDL4JTest {
return DEFAULT_THREADS; return DEFAULT_THREADS;
} }
/**
* Override this method to set the default timeout for methods in the test class
*/
public long getTimeoutMilliseconds(){
return 90_000;
}
/** /**
* Override this to set the profiling mode for the tests defined in the child class * Override this to set the profiling mode for the tests defined in the child class
*/ */

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.earlystopping; package org.deeplearning4j.earlystopping;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils; import org.deeplearning4j.TestUtils;
@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
} }
@Data @Data
@EqualsAndHashCode(callSuper = false)
public static class TestListener extends BaseTrainingListener { public static class TestListener extends BaseTrainingListener {
private int countEpochStart = 0; private int countEpochStart = 0;
private int countEpochEnd = 0; private int countEpochEnd = 0;

View File

@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
@EqualsAndHashCode @EqualsAndHashCode(callSuper = false)
public class CustomActivation extends BaseActivationFunction implements IActivation { public class CustomActivation extends BaseActivationFunction implements IActivation {
@Override @Override
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.layers.samediff.testlayers; package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams; import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
@ -37,6 +38,7 @@ import java.util.Map;
@NoArgsConstructor @NoArgsConstructor
@Data @Data
@EqualsAndHashCode(callSuper = false)
public class SameDiffDenseVertex extends SameDiffVertex { public class SameDiffDenseVertex extends SameDiffVertex {
private int nIn; private int nIn;

View File

@ -27,6 +27,7 @@ dependencies {
implementation "com.fasterxml.jackson.core:jackson-core" implementation "com.fasterxml.jackson.core:jackson-core"
implementation "com.fasterxml.jackson.core:jackson-databind" implementation "com.fasterxml.jackson.core:jackson-databind"
implementation "org.slf4j:slf4j-api" implementation "org.slf4j:slf4j-api"
implementation "org.slf4j:slf4j-simple"
implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerModel implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerModel
implementation projects.cavisNd4j.cavisNd4jAeron implementation projects.cavisNd4j.cavisNd4jAeron
implementation projects.cavisDnn.cavisDnnApi implementation projects.cavisDnn.cavisDnnApi

View File

@ -56,7 +56,11 @@ public class RemoteParameterServerClientTests extends BaseND4JTest {
new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true) new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true)
.termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy())
.receiverIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy())
.senderIdleStrategy(new BusySpinIdleStrategy()); .senderIdleStrategy(new BusySpinIdleStrategy())
.driverTimeoutMs(1000*1000 *1000)
.clientLivenessTimeoutNs(1000*1000*1000)
.timerIntervalNs( 1000 * 1000);
mediaDriver = MediaDriver.launchEmbedded(ctx); mediaDriver = MediaDriver.launchEmbedded(ctx);
aeron = Aeron.connect(getContext()); aeron = Aeron.connect(getContext());

View File

@ -325,7 +325,7 @@ public class ParameterServerSubscriber implements AutoCloseable {
int tries=0; int tries=0;
while (!subscriber.launched() && tries<12) { while (!subscriber.launched() && tries<12) {
tries++; tries++;
Thread.sleep(1000); Thread.sleep(2000);
} }
if(!subscriber.launched()) { if(!subscriber.launched()) {
throw new Exception("Subscriber did not start in time."); throw new Exception("Subscriber did not start in time.");