diff --git a/nd4j/README.md b/nd4j/README.md
index f5ed19b7c..c4dbf445a 100644
--- a/nd4j/README.md
+++ b/nd4j/README.md
@@ -14,7 +14,6 @@ Please search for the latest version on search.maven.org.
Or use the versions displayed in:
https://github.com/eclipse/deeplearning4j-examples/blob/master/pom.xml
-
---
## Main Features
@@ -47,6 +46,29 @@ To install ND4J, there are a couple of approaches, and more information can be f
#### Clone from the GitHub Repo
https://deeplearning4j.org/docs/latest/deeplearning4j-build-from-source
+
+#### Build from sources
+
+To build `ND4J` from sources launch from the present directory:
+
+```shell script
+$ mvn clean install -DskipTests=true
+```
+
+To run tests using CPU or CUDA backend run the following.
+
+For CPU:
+
+```shell script
+$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cpu -P nd4j-tf-cpu
+```
+
+For CUDA:
+
+```shell script
+$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cuda -P nd4j-tf-gpu
+```
+
## Contribute
1. Check for open issues, or open a new issue to start a discussion around a feature idea or a bug.
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index 09eb32ea2..916ba0a08 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -1,4 +1,4 @@
-// Targeted by JavaCPP version 1.5.4-SNAPSHOT: DO NOT EDIT THIS FILE
+// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas;
@@ -6,6 +6,10 @@ import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
+import static org.bytedeco.javacpp.presets.javacpp.*;
+import static org.bytedeco.openblas.global.openblas_nolapack.*;
+import static org.bytedeco.openblas.global.openblas.*;
+
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
static { Loader.load(); }
diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml
index a62f64ed5..4900902ef 100644
--- a/nd4j/nd4j-backends/nd4j-tests/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml
@@ -128,7 +128,6 @@
${project.version}
test
-
com.github.stephenc.jcip
jcip-annotations
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
index dcc82e0aa..48d433214 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
@@ -17,12 +17,14 @@
package org.nd4j.autodiff;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest;
@@ -30,6 +32,36 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
+import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
+import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
+import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
+import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
+import org.nd4j.linalg.api.ops.custom.SpTreeCell;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastGradientArgs;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
+import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
@@ -55,6 +87,15 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
+import org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater;
+import org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater;
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
import org.nd4j.linalg.api.ops.persistence.SaveV2;
import org.nd4j.linalg.api.ops.util.PrintAffinity;
@@ -66,13 +107,17 @@ import org.reflections.Reflections;
import java.io.File;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
public class TestOpMapping extends BaseNd4jTest {
Set> subTypes;
@@ -303,9 +348,6 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(PrintVariable.class);
s.add(PrintAffinity.class);
s.add(Assign.class);
-
-
-
}
@Test @Ignore
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index aad5f4a61..d4499dbf5 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -35,6 +35,15 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
index e24aa65fc..7e5af6864 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java
@@ -32,6 +32,18 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.custom.Digamma;
+import org.nd4j.linalg.api.ops.custom.DivideNoNan;
+import org.nd4j.linalg.api.ops.custom.Flatten;
+import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
+import org.nd4j.linalg.api.ops.custom.Igamma;
+import org.nd4j.linalg.api.ops.custom.Igammac;
+import org.nd4j.linalg.api.ops.custom.Lgamma;
+import org.nd4j.linalg.api.ops.custom.Lu;
+import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
+import org.nd4j.linalg.api.ops.custom.Polygamma;
+import org.nd4j.linalg.api.ops.custom.Roll;
+import org.nd4j.linalg.api.ops.custom.TriangularSolve;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
index 7ca72f7c7..ca904731f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java
@@ -24,6 +24,17 @@ import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
+import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
index f3db04faf..080c4b22f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java
@@ -38,7 +38,22 @@ import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
+import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
+import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
+import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
+import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
index c20952c5b..eb41e7abc 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java
@@ -35,6 +35,13 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.Tri;
import org.nd4j.linalg.api.ops.custom.Triu;
+import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
+import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
+import org.nd4j.linalg.api.ops.impl.shape.Permute;
+import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
+import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
+import org.nd4j.linalg.api.ops.impl.shape.Transpose;
+import org.nd4j.linalg.api.ops.impl.shape.Unstack;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
index b502a23be..76dbc4e1a 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
@@ -48,8 +48,26 @@ import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
@@ -2059,9 +2077,6 @@ public class TransformOpValidation extends BaseOpValidation {
);
assertNull(err);
-
-
-
}
@Test
@@ -2085,7 +2100,6 @@ public class TransformOpValidation extends BaseOpValidation {
}
-
@Test
public void testEmbeddingLookup() {
@@ -2243,11 +2257,5 @@ public class TransformOpValidation extends BaseOpValidation {
.gradientCheck(true));
assertNull(err);
}
-
-
}
-
-
-
- }
-
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java
index 82982e260..a47123d66 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java
@@ -22,6 +22,14 @@ import static org.junit.Assert.fail;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest {
@@ -487,8 +495,6 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
-
-
@Test
public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
index ea88b45ed..2d9dcacac 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java
@@ -22,6 +22,11 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.functions.DifferentialFunction;
+import org.nd4j.graph.FlatConfiguration;
+import org.nd4j.graph.FlatGraph;
+import org.nd4j.graph.FlatNode;
+import org.nd4j.graph.FlatVariable;
+import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -30,6 +35,17 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.GradientUpdater;
+import org.nd4j.linalg.learning.config.AMSGrad;
+import org.nd4j.linalg.learning.config.AdaDelta;
+import org.nd4j.linalg.learning.config.AdaGrad;
+import org.nd4j.linalg.learning.config.AdaMax;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.IUpdater;
+import org.nd4j.linalg.learning.config.Nadam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.learning.config.NoOp;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
index f516c9813..6664d3d6c 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/GraphTransformUtilTests.java
@@ -18,6 +18,11 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
+import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
+import org.nd4j.autodiff.samediff.transform.OpPredicate;
+import org.nd4j.autodiff.samediff.transform.SubGraph;
+import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
+import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
index 5a376cde6..3cb274ebd 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java
@@ -45,6 +45,12 @@ import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.enums.WeightsFormat;
import org.nd4j.evaluation.IEvaluation;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationBinary;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROC;
+import org.nd4j.evaluation.classification.ROCBinary;
+import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.Activation;
@@ -3485,42 +3491,42 @@ public class SameDiffTests extends BaseNd4jTest {
}
@Test
- public void testConcatVariableGrad() {
- SameDiff sd = SameDiff.create();
- SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
- SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
- SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
- INDArray inputArr = Nd4j.rand(3,4);
- INDArray labelArr = Nd4j.rand(3,4);
- SDVariable c = sd.concat("concat", 1, a, b);
- SDVariable loss = sd.math().pow(c.sub(label), 2);
- sd.setLossVariables(loss);
- sd.associateArrayWithVariable(labelArr, label);
- sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
- sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
- Map map = sd.calculateGradients(null, "a", "b", "concat");
- INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
- assertEquals(concatArray, map.get("concat"));
+ public void testConcatVariableGrad() {
+ SameDiff sd = SameDiff.create();
+ SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
+ SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
+ SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
+ INDArray inputArr = Nd4j.rand(3,4);
+ INDArray labelArr = Nd4j.rand(3,4);
+ SDVariable c = sd.concat("concat", 1, a, b);
+ SDVariable loss = sd.math().pow(c.sub(label), 2);
+ sd.setLossVariables(loss);
+ sd.associateArrayWithVariable(labelArr, label);
+ sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
+ sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
+ Map map = sd.calculateGradients(null, "a", "b", "concat");
+ INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
+ assertEquals(concatArray, map.get("concat"));
- }
+ }
- @Test
- public void testSliceVariableGrad() {
- SameDiff sd = SameDiff.create();
- SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
- SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
- INDArray inputArr = Nd4j.rand(3,4);
- INDArray labelArr = Nd4j.rand(3,4);
- SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
- SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
- SDVariable c = sd.concat("concat", 1, a, b);
- SDVariable loss = sd.math().pow(c.sub(label), 2);
- sd.setLossVariables(loss);
- sd.associateArrayWithVariable(labelArr, label);
- sd.associateArrayWithVariable(inputArr, input);
- Map map = sd.calculateGradients(null,"input", "concat");
- assertEquals(map.get("input"), map.get("concat"));
- }
+ @Test
+ public void testSliceVariableGrad() {
+ SameDiff sd = SameDiff.create();
+ SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
+ SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
+ INDArray inputArr = Nd4j.rand(3,4);
+ INDArray labelArr = Nd4j.rand(3,4);
+ SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
+ SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
+ SDVariable c = sd.concat("concat", 1, a, b);
+ SDVariable loss = sd.math().pow(c.sub(label), 2);
+ sd.setLossVariables(loss);
+ sd.associateArrayWithVariable(labelArr, label);
+ sd.associateArrayWithVariable(inputArr, input);
+ Map map = sd.calculateGradients(null,"input", "concat");
+ assertEquals(map.get("input"), map.get("concat"));
+ }
@Test
public void testTrainingConfigJson(){
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java
index ab546ebae..6dea34e01 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/listeners/ListenerTest.java
@@ -17,6 +17,13 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.Test;
+import org.nd4j.autodiff.listeners.At;
+import org.nd4j.autodiff.listeners.BaseListener;
+import org.nd4j.autodiff.listeners.Listener;
+import org.nd4j.autodiff.listeners.ListenerResponse;
+import org.nd4j.autodiff.listeners.ListenerVariables;
+import org.nd4j.autodiff.listeners.Loss;
+import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
@@ -351,7 +358,7 @@ public class ListenerTest extends BaseNd4jTest {
}
@Override
- public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
+ public void iterationDone(final SameDiff sd, final At at, final MultiDataSet dataSet, final Loss loss) {
iterationDoneCount++;
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java
index 749a19848..843cf3a31 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/FileReadWriteTests.java
@@ -28,6 +28,13 @@ import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
+import org.nd4j.graph.FlatArray;
+import org.nd4j.graph.UIAddName;
+import org.nd4j.graph.UIEvent;
+import org.nd4j.graph.UIGraphStructure;
+import org.nd4j.graph.UIInfoType;
+import org.nd4j.graph.UIOp;
+import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
@@ -150,7 +157,6 @@ public class FileReadWriteTests extends BaseNd4jTest {
assertEquals(UIInfoType.START_EVENTS, read.getData().get(1).getFirst().infoType());
-
//Append a number of events
w.registerEventName("accuracy");
for( int iter=0; iter<3; iter++) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java
index 834111ae2..4346ec8a3 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java
@@ -1,6 +1,12 @@
package org.nd4j.evaluation;
import org.junit.Test;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationBinary;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROC;
+import org.nd4j.evaluation.classification.ROCBinary;
+import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
@@ -107,7 +113,6 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
assertTrue(t.getMessage(), t.getMessage().contains("no data"));
}
}
-
}
@Test
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java
index 945a583ca..54e16713b 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvalJsonTest.java
@@ -17,6 +17,12 @@
package org.nd4j.evaluation;
import org.junit.Test;
+import org.nd4j.evaluation.classification.Evaluation;
+import org.nd4j.evaluation.classification.EvaluationBinary;
+import org.nd4j.evaluation.classification.EvaluationCalibration;
+import org.nd4j.evaluation.classification.ROC;
+import org.nd4j.evaluation.classification.ROCBinary;
+import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
@@ -100,8 +106,6 @@ public class EvalJsonTest extends BaseNd4jTest {
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
-
-
for (IEvaluation e : arr) {
String json = e.toJson();
if (print) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java
index 3be481b0b..bdda34338 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/BERTGraphTest.java
@@ -22,6 +22,11 @@ import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
+import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
+import org.nd4j.autodiff.samediff.transform.OpPredicate;
+import org.nd4j.autodiff.samediff.transform.SubGraph;
+import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
+import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.tensorflow.TFImportOverride;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java
index 415ef64a8..37fa40689 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestSuite.java
@@ -35,7 +35,7 @@ import java.util.ServiceLoader;
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
//the system property for what backends should run
public final static String BACKENDS_TO_LOAD = "backends";
- private static List BACKENDS;
+ private static List BACKENDS = new ArrayList<>();
static {
ServiceLoader loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
for (Nd4jBackend backend : loadedBackends) {
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
index 47aa4a679..67fc088c6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java
@@ -16,6 +16,13 @@
package org.nd4j.linalg;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
@@ -23,10 +30,18 @@ import lombok.var;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.commons.math3.util.FastMath;
-import org.junit.*;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
+import org.nd4j.common.io.ClassPathResource;
+import org.nd4j.common.primitives.Pair;
+import org.nd4j.common.util.ArrayUtil;
+import org.nd4j.common.util.MathUtils;
import org.nd4j.enums.WeightsFormat;
import org.nd4j.imports.TFGraphs.NodeReader;
import org.nd4j.linalg.api.blas.Level1;
@@ -47,6 +62,14 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
+import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
@@ -64,6 +87,11 @@ import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
+import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
+import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
+import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
@@ -73,8 +101,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
-import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
+import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
@@ -94,20 +122,28 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
-import org.nd4j.common.primitives.Pair;
-import org.nd4j.common.util.ArrayUtil;
-import org.nd4j.common.util.MathUtils;
-import java.io.*;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
-import java.util.*;
-
-import static org.junit.Assert.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
/**
* NDArrayTests
@@ -148,8 +184,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Nd4j.setDataType(initialType);
}
-
-
@Test
public void testArangeNegative() {
INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE);
@@ -241,9 +275,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray inDup = in.dup();
-// System.out.println(in);
-// System.out.println(inDup);
-
assertEquals(arr, in); //Passes: Original array "in" is OK, but array "inDup" is not!?
assertEquals(in, inDup); //Fails
}
@@ -310,7 +341,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(assertion,test);
}
-
@Test
public void testAudoBroadcastAddMatrix() {
INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2);
@@ -336,7 +366,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
-
@Test
public void testTensorAlongDimension() {
val shape = new long[] {4, 5, 7};
@@ -538,7 +567,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
-
@Test
public void testGetColumns() {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
@@ -2719,7 +2747,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, zOutCF); //fails
}
-
@Test
public void testBroadcastDiv() {
INDArray num = Nd4j.create(new double[] {1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 1.00, 1.00, 1.00, 1.00,
@@ -2753,7 +2780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testBroadcastMult() {
INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00,
@@ -2796,7 +2822,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(expected, actual);
}
-
@Test
public void testDimension() {
INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
@@ -4595,8 +4620,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
-1.25485503673});
INDArray reduced = Nd4j.getExecutioner().exec(new CosineDistance(haystack, needle, 1));
-// log.info("Reduced: {}", reduced);
-
INDArray exp = Nd4j.create(new double[] {0.577452, 0.0, 1.80182});
assertEquals(exp, reduced);
@@ -4606,9 +4629,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new CosineDistance(row, needle)).z().getDouble(0);
assertEquals("Failed at " + i, reduced.getDouble(i), res, 1e-5);
}
- //cosinedistance([-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
- //cosinedistance([.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
-
}
@Test
@@ -4677,8 +4697,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
-
-// log.info("Euclidean: {} vs {} is {}", x, needle, res);
}
}
@@ -4698,8 +4716,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
-
-// log.info("Euclidean: {} vs {} is {}", x, needle, res);
}
}
@@ -4720,8 +4736,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
-
-// log.info("Cosine: {} vs {} is {}", x, needle, res);
}
}
@@ -4755,7 +4769,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testAtan2_1() {
INDArray x = Nd4j.create(10).assign(-1.0);
@@ -4767,7 +4780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, z);
}
-
@Test
public void testAtan2_2() {
INDArray x = Nd4j.create(10).assign(1.0);
@@ -4779,7 +4791,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, z);
}
-
@Test
public void testJaccardDistance1() {
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0});
@@ -4790,7 +4801,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(0.75, val, 1e-5);
}
-
@Test
public void testJaccardDistance2() {
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 1});
@@ -4811,7 +4821,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(2.0 / 6, val, 1e-5);
}
-
@Test
public void testHammingDistance2() {
INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0});
@@ -4822,7 +4831,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(3.0 / 6, val, 1e-5);
}
-
@Test
public void testHammingDistance3() {
INDArray x = Nd4j.create(DataType.DOUBLE, 10, 6);
@@ -4831,7 +4839,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
x.getRow(r).putScalar(p, 1);
}
-// log.info("X: {}", x);
INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0});
INDArray res = Nd4j.getExecutioner().exec(new HammingDistance(x, y, 1));
@@ -4846,7 +4853,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testAllDistances1() {
INDArray initialX = Nd4j.create(5, 10);
@@ -4879,7 +4885,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testAllDistances2() {
INDArray initialX = Nd4j.create(5, 10);
@@ -4940,7 +4945,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testAllDistances3_Large() {
INDArray initialX = Nd4j.create(5, 2000);
@@ -4968,13 +4972,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = result.getDouble(x, y);
double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup());
- //log.info("Expected [{}, {}]: {}",x, y, exp);
assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001);
}
}
}
-
@Test
public void testAllDistances3_Large_Columns() {
INDArray initialX = Nd4j.create(2000, 5);
@@ -5005,7 +5007,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testAllDistances4_Large_Columns() {
INDArray initialX = Nd4j.create(2000, 5);
@@ -5095,8 +5096,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
-
@Test
public void testAllDistances3() {
Nd4j.getRandom().setSeed(123);
@@ -5122,7 +5121,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testStridedTransforms1() {
//output: Rank: 2,Offset: 0
@@ -5176,7 +5174,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testEntropy3() {
INDArray x = Nd4j.rand(1, 100);
@@ -5197,7 +5194,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, res, 1e-5);
}
-
protected double getShannonEntropy(double[] array) {
double ret = 0;
for (double x : array) {
@@ -5207,12 +5203,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
return -ret;
}
-
protected double getLogEntropy(double[] array) {
return Math.log(MathUtils.entropy(array));
}
-
@Test
public void testReverse1() {
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
@@ -5228,8 +5222,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
-// log.info("Array shapeInfo: {}", array.shapeInfoJava());
-
INDArray rev = Nd4j.reverse(array);
assertEquals(exp, rev);
@@ -5278,7 +5270,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertTrue(rev == array);
}
-
@Test
public void testNativeSortView1() {
INDArray matrix = Nd4j.create(10, 10);
@@ -5291,9 +5282,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Nd4j.sort(matrix.getColumn(0), true);
-
-// log.info("Matrix: {}", matrix);
-
assertEquals(exp, matrix.getColumn(0));
}
@@ -5384,9 +5372,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Transforms.reverse(array, false);
-// log.info("Reversed shapeInfo: {}", array.shapeInfoJava());
-// log.info("Reversed: {}", array);
-
Transforms.reverse(array, false);
val jexp = exp.data().asInt();
@@ -5401,9 +5386,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
val exp = array.dup(array.ordering());
val reversed = Transforms.reverse(array, true);
-
-// log.info("Reversed: {}", reversed);
-
val rereversed = Transforms.reverse(reversed, true);
val jexp = exp.data().asInt();
@@ -5445,8 +5427,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
INDArray exp = array.dup();
Transforms.reverse(array, false);
-// log.info("Reverse: {}", array);
-
long time1 = System.currentTimeMillis();
INDArray res = Nd4j.sort(array, true);
@@ -5465,7 +5445,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertNotEquals(exp1, dps);
-
for (int r = 0; r < array.rows(); r++) {
array.getRow(r).assign(dps);
}
@@ -5485,7 +5464,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
protected boolean checkIfUnique(INDArray array, int iteration) {
var jarray = array.data().asInt();
var set = new HashSet();
@@ -5698,7 +5676,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testBroadcastAMax() {
INDArray matrix = Nd4j.create(5, 5);
@@ -5715,7 +5692,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
-
@Test
public void testBroadcastAMin() {
INDArray matrix = Nd4j.create(5, 5);
@@ -5767,7 +5743,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, res);
}
-
@Test
public void testRDiv1() {
val argX = Nd4j.create(3).assign(2.0);
@@ -5789,7 +5764,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(arrayC, arrayF);
}
-
@Test
public void testMatchTransform() {
val array = Nd4j.create(new double[] {1, 1, 1, 0, 1, 1},'c');
@@ -5838,10 +5812,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
val a = Nd4j.linspace(1, x * A1 * A2, x * A1 * A2, DataType.DOUBLE).reshape(x, A1, A2);
val b = Nd4j.linspace(1, x * B1 * B2, x * B1 * B2, DataType.DOUBLE).reshape(x, B1, B2);
-
- //
-
- //log.info("C shape: {}", Arrays.toString(c.shapeInfoDataBuffer().asInt()));
}
@Test
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java
index 9e3e4dfc5..504e3bc34 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/activations/TestActivation.java
@@ -21,6 +21,21 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.activations.impl.ActivationCube;
+import org.nd4j.linalg.activations.impl.ActivationELU;
+import org.nd4j.linalg.activations.impl.ActivationGELU;
+import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationHardTanH;
+import org.nd4j.linalg.activations.impl.ActivationIdentity;
+import org.nd4j.linalg.activations.impl.ActivationLReLU;
+import org.nd4j.linalg.activations.impl.ActivationRReLU;
+import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
+import org.nd4j.linalg.activations.impl.ActivationReLU;
+import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
+import org.nd4j.linalg.activations.impl.ActivationSoftSign;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@@ -73,7 +88,6 @@ public class TestActivation extends BaseNd4jTest {
double[] dIn = in.data().asDouble();
for( int i=0; i 5000);
}
@@ -556,14 +554,8 @@ public class IndexingTestsC extends BaseNd4jTest {
char order = 'c';
INDArray arr = Nd4j.linspace(DataType.FLOAT, 1, prod, prod).reshape('c', inShape).dup(order);
INDArray sub = arr.get(indexes);
-
-// System.out.println(Arrays.toString(indexes));
-// System.out.println(arr);
-// System.out.println();
-// System.out.println(sub);
}
-
@Override
public char ordering() {
return 'c';
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
index cdaaf1be5..2b02b9a16 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/convolution/DeconvTests.java
@@ -1,8 +1,13 @@
package org.nd4j.linalg.convolution;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
+import org.nd4j.common.io.ClassPathResource;
+import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -10,12 +15,13 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.common.resources.Resources;
import java.io.File;
-import java.util.*;
-
-import static org.junit.Assert.*;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
public class DeconvTests extends BaseNd4jTest {
@@ -33,10 +39,10 @@ public class DeconvTests extends BaseNd4jTest {
@Test
public void compareKeras() throws Exception {
- File f = testDir.newFolder();
- Resources.copyDirectory("keras/deconv", f);
+ File newFolder = testDir.newFolder();
+ new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
- File[] files = f.listFiles();
+ File[] files = newFolder.listFiles();
Set tests = new HashSet<>();
for(File file : files){
@@ -64,10 +70,10 @@ public class DeconvTests extends BaseNd4jTest {
int d = Integer.parseInt(nums[5]);
boolean nchw = s.contains("nchw");
- INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy"));
- INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy"));
- INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT);
- INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy"));
+ INDArray w = Nd4j.readNpy(new File(newFolder, s + "_W.npy"));
+ INDArray b = Nd4j.readNpy(new File(newFolder, s + "_b.npy"));
+ INDArray in = Nd4j.readNpy(new File(newFolder, s + "_in.npy")).castTo(DataType.FLOAT);
+ INDArray expOut = Nd4j.readNpy(new File(newFolder, s + "_out.npy"));
CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(in, w, b)
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index e8b5a214e..45322ef71 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -26,6 +26,37 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.custom.AdjustContrast;
+import org.nd4j.linalg.api.ops.custom.AdjustHue;
+import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
+import org.nd4j.linalg.api.ops.custom.BetaInc;
+import org.nd4j.linalg.api.ops.custom.BitCast;
+import org.nd4j.linalg.api.ops.custom.CompareAndBitpack;
+import org.nd4j.linalg.api.ops.custom.DivideNoNan;
+import org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes;
+import org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel;
+import org.nd4j.linalg.api.ops.custom.Flatten;
+import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
+import org.nd4j.linalg.api.ops.custom.HsvToRgb;
+import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
+import org.nd4j.linalg.api.ops.custom.Lgamma;
+import org.nd4j.linalg.api.ops.custom.LinearSolve;
+import org.nd4j.linalg.api.ops.custom.Logdet;
+import org.nd4j.linalg.api.ops.custom.Lstsq;
+import org.nd4j.linalg.api.ops.custom.Lu;
+import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
+import org.nd4j.linalg.api.ops.custom.Polygamma;
+import org.nd4j.linalg.api.ops.custom.RandomCrop;
+import org.nd4j.linalg.api.ops.custom.RgbToGrayscale;
+import org.nd4j.linalg.api.ops.custom.RgbToHsv;
+import org.nd4j.linalg.api.ops.custom.RgbToYiq;
+import org.nd4j.linalg.api.ops.custom.RgbToYuv;
+import org.nd4j.linalg.api.ops.custom.Roll;
+import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
+import org.nd4j.linalg.api.ops.custom.ToggleBits;
+import org.nd4j.linalg.api.ops.custom.TriangularSolve;
+import org.nd4j.linalg.api.ops.custom.YiqToRgb;
+import org.nd4j.linalg.api.ops.custom.YuvToRgb;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
@@ -373,7 +404,6 @@ public class CustomOpsTests extends BaseNd4jTest {
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
Nd4j.getExecutioner().exec(op);
-// log.info("Matrix: {}", matrix);
assertEquals(exp0, matrix.getRow(0));
assertEquals(exp1, matrix.getRow(1));
assertEquals(exp0, matrix.getRow(2));
@@ -1384,8 +1414,6 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
val c = Conditions.equals(0.0);
-// System.out.println("Y:\n" + y);
-
INDArray z = x.match(y, c);
INDArray exp = Nd4j.createFromArray(new boolean[][]{
{false, false, false},
@@ -1396,7 +1424,6 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(exp, z);
}
-
@Test
public void testCreateOp_1() {
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
@@ -1862,11 +1889,9 @@ public class CustomOpsTests extends BaseNd4jTest {
System.out.println("in: " + in.shapeInfoToString());
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
-
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
INDArray out2 = out.like();
-
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
.addInputs(in)
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java
index 08415415a..c6449e7dc 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerSerializerTest.java
@@ -22,6 +22,14 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer;
+import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.CustomSerializerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
@@ -65,7 +73,6 @@ public class NormalizerSerializerTest extends BaseNd4jTest {
ImagePreProcessingScaler restored = SUT.restore(tmpFile);
assertEquals(imagePreProcessingScaler,restored);
-
}
@Test
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
index 2ed42e85b..a50b2ac5f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
@@ -25,6 +25,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
+import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
+import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java
index 81b4e4e75..763209756 100755
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterTest.java
@@ -24,6 +24,12 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.linalg.learning.config.AdaDelta;
+import org.nd4j.linalg.learning.config.AdaGrad;
+import org.nd4j.linalg.learning.config.AdaMax;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nadam;
+import org.nd4j.linalg.learning.config.Nesterovs;
import static org.junit.Assert.assertEquals;
@@ -53,7 +59,6 @@ public class UpdaterTest extends BaseNd4jTest {
int rows = 10;
int cols = 2;
-
NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
INDArray W = Nd4j.zeros(rows, cols);
@@ -68,13 +73,11 @@ public class UpdaterTest extends BaseNd4jTest {
}
}
-
@Test
public void testAdaGrad() {
int rows = 10;
int cols = 2;
-
AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
INDArray W = Nd4j.zeros(rows, cols);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java
index 43ac02fc9..e1314a650 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/learning/UpdaterValidation.java
@@ -23,6 +23,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.linalg.learning.config.AMSGrad;
+import org.nd4j.linalg.learning.config.AdaDelta;
+import org.nd4j.linalg.learning.config.AdaGrad;
+import org.nd4j.linalg.learning.config.AdaMax;
+import org.nd4j.linalg.learning.config.Adam;
+import org.nd4j.linalg.learning.config.Nadam;
+import org.nd4j.linalg.learning.config.Nesterovs;
+import org.nd4j.linalg.learning.config.RmsProp;
+import org.nd4j.linalg.learning.config.Sgd;
import java.util.HashMap;
import java.util.Map;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java
index 5a2fc38b1..fca7dcf5f 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionJson.java
@@ -21,6 +21,21 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
+import org.nd4j.linalg.lossfunctions.impl.LossCosineProximity;
+import org.nd4j.linalg.lossfunctions.impl.LossHinge;
+import org.nd4j.linalg.lossfunctions.impl.LossKLD;
+import org.nd4j.linalg.lossfunctions.impl.LossL1;
+import org.nd4j.linalg.lossfunctions.impl.LossL2;
+import org.nd4j.linalg.lossfunctions.impl.LossMAE;
+import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
+import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
+import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
+import org.nd4j.linalg.lossfunctions.impl.LossMultiLabel;
+import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
+import org.nd4j.linalg.lossfunctions.impl.LossPoisson;
+import org.nd4j.linalg.lossfunctions.impl.LossSquaredHinge;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
index 4ae80e176..c5863ece2 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
@@ -28,6 +28,16 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
+import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
+import org.nd4j.linalg.lossfunctions.impl.LossL1;
+import org.nd4j.linalg.lossfunctions.impl.LossL2;
+import org.nd4j.linalg.lossfunctions.impl.LossMAE;
+import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
+import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
+import org.nd4j.linalg.lossfunctions.impl.LossMSE;
+import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
+import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
+import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertTrue;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
index a2aad6348..214e8dae3 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/nativ/OpsMappingTests.java
@@ -24,6 +24,13 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.ops.BaseBroadcastOp;
+import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
+import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
+import org.nd4j.linalg.api.ops.BaseScalarOp;
+import org.nd4j.linalg.api.ops.BaseTransformSameOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
index f253f4fc3..ef868d7ce 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
@@ -26,6 +26,12 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
@@ -84,7 +90,6 @@ public class DerivativeTests extends BaseNd4jTest {
}
}
-
@Test
public void testRectifiedLinearDerivative() {
//ReLU:
@@ -166,11 +171,7 @@ public class DerivativeTests extends BaseNd4jTest {
}
INDArray z = Transforms.hardSigmoid(xArr, true);
- INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));
-
-// System.out.println(xArr);
-// System.out.println(z);
-// System.out.println(zPrime);
+ INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));;
for (int i = 0; i < expHSOut.length; i++) {
double relErrorHS =
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
index 31976f123..135e8cb45 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java
@@ -32,6 +32,21 @@ import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
+import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
+import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
+import org.nd4j.linalg.api.ops.random.custom.RandomGamma;
+import org.nd4j.linalg.api.ops.random.custom.RandomPoisson;
+import org.nd4j.linalg.api.ops.random.custom.RandomShuffle;
+import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
+import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
+import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
+import org.nd4j.linalg.api.ops.random.impl.DropOut;
+import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
+import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
+import org.nd4j.linalg.api.ops.random.impl.Linspace;
+import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
+import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
+import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
@@ -78,7 +93,6 @@ public class RandomTests extends BaseNd4jTest {
@Test
public void testCrossBackendEquality1() {
-
int[] shape = {12};
double mean = 0;
double standardDeviation = 1.0;
@@ -87,8 +101,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray arr = Nd4j.getExecutioner().exec(new GaussianDistribution(
Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom());
-
-// log.info("arr: {}", arr.data().asDouble());
assertEquals(exp, arr);
}
@@ -105,8 +117,6 @@ public class RandomTests extends BaseNd4jTest {
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
Nd4j.getExecutioner().exec(distribution2, random2);
-// System.out.println("Data: " + z1);
-// System.out.println("Data: " + z2);
for (int e = 0; e < z1.length(); e++) {
double val = z1.getDouble(e);
assertTrue(val >= 1.0 && val <= 2.0);
@@ -135,8 +145,6 @@ public class RandomTests extends BaseNd4jTest {
log.info("States cpu: {}/{}", random1.rootState(), random1.nodeState());
-// System.out.println("Data: " + z1);
-// System.out.println("Data: " + z2);
for (int e = 0; e < z1.length(); e++) {
double val = z1.getDouble(e);
assertTrue(val >= 1.0 && val <= 2.0);
@@ -156,9 +164,6 @@ public class RandomTests extends BaseNd4jTest {
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
Nd4j.getExecutioner().exec(distribution2, random1);
-// System.out.println("Data: " + z1);
-// System.out.println("Data: " + z2);
-
assertNotEquals(z1, z2);
}
@@ -174,7 +179,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.randn('c', new int[] {1, 1000});
assertEquals("Failed on iteration " + i, z1, z2);
-
}
}
@@ -190,7 +194,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.rand('c', new int[] {1, 1000});
assertEquals("Failed on iteration " + i, z1, z2);
-
}
}
@@ -206,7 +209,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.getExecutioner().exec(new BinomialDistribution(Nd4j.createUninitialized(1000), 10, 0.2));
assertEquals("Failed on iteration " + i, z1, z2);
-
}
}
@@ -222,8 +224,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z1, z2);
}
-
-
@Test
public void testDropoutInverted1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -318,7 +318,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z1, z2);
}
-
@Test
public void testGaussianDistribution2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -403,8 +402,6 @@ public class RandomTests extends BaseNd4jTest {
Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
Nd4j.sort(z1, true);
-// System.out.println("Data for Anderson-Darling: " + z1);
-
for (int i = 0; i < n; i++) {
Double res = nd.cumulativeProbability(z1.getDouble(i));
@@ -432,9 +429,6 @@ public class RandomTests extends BaseNd4jTest {
public void testStepOver1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
-
-// log.info("1: ----------------");
-
INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0));
assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
@@ -442,36 +436,15 @@ public class RandomTests extends BaseNd4jTest {
random1.setSeed(119);
-// log.info("2: ----------------");
-
INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000);
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
Nd4j.getExecutioner().exec(op1, random1);
-// log.info("2: ----------------");
-
- //log.info("End: [{}, {}, {}, {}]", z1.getFloat(29000000), z1.getFloat(29000001), z1.getFloat(29000002), z1.getFloat(29000003));
-
- //log.info("Sum: {}", z1.sumNumber().doubleValue());
-// log.info("Sum2: {}", z2.sumNumber().doubleValue());
-
-
INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()));
-// log.info("NaNs: {}", match);
assertEquals(0.0f, match.getFloat(0), 0.01f);
- /*
- for (int i = 0; i < z1.length(); i++) {
- if (Double.isNaN(z1.getDouble(i)))
- throw new IllegalStateException("NaN value found at " + i);
-
- if (Double.isInfinite(z1.getDouble(i)))
- throw new IllegalStateException("Infinite value found at " + i);
- }
- */
-
assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
}
@@ -480,7 +453,6 @@ public class RandomTests extends BaseNd4jTest {
public void testSum_119() {
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
val sum = z2.sumNumber().doubleValue();
-// log.info("Sum2: {}", sum);
assertEquals(0.0, sum, 1e-5);
}
@@ -493,7 +465,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
}
-
@Test
public void testSetSeed1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -533,8 +504,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z02, z12);
}
-
-
@Test
public void testJavaSide1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -553,8 +522,6 @@ public class RandomTests extends BaseNd4jTest {
assertArrayEquals(array1, array2, 1e-5f);
}
-
-
@Test
public void testJavaSide2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -574,7 +541,6 @@ public class RandomTests extends BaseNd4jTest {
assertArrayEquals(array1, array2);
}
-
@Test
public void testJavaSide3() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -657,8 +623,6 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(0, sum);
}
-
-
@Test
public void testBernoulliDistribution1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -677,11 +641,8 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(z1Dup, z1);
assertEquals(z1, z2);
-
-
}
-
@Test
public void testBernoulliDistribution2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -690,7 +651,8 @@ public class RandomTests extends BaseNd4jTest {
INDArray z1 = Nd4j.zeros(20);
INDArray z2 = Nd4j.zeros(20);
INDArray z1Dup = Nd4j.zeros(20);
- INDArray exp = Nd4j.create(new double[] {0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000});
+ INDArray exp = Nd4j.create(new double[]{ 0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000,
+ 1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000 });
BernoulliDistribution op1 = new BernoulliDistribution(z1, 0.50);
BernoulliDistribution op2 = new BernoulliDistribution(z2, 0.50);
@@ -705,7 +667,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, z1);
}
-
@Test
public void testBernoulliDistribution3() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -716,7 +677,7 @@ public class RandomTests extends BaseNd4jTest {
INDArray z1 = Nd4j.zeros(10);
INDArray z2 = Nd4j.zeros(10);
INDArray z1Dup = Nd4j.zeros(10);
- INDArray exp = Nd4j.create(new double[] {1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0});
+ INDArray exp = Nd4j.create(new double[]{ 1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0 });
BernoulliDistribution op1 = new BernoulliDistribution(z1, prob);
BernoulliDistribution op2 = new BernoulliDistribution(z2, prob);
@@ -731,7 +692,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, z1);
}
-
@Test
public void testBinomialDistribution1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@@ -780,7 +740,6 @@ public class RandomTests extends BaseNd4jTest {
BooleanIndexing.and(z1, Conditions.greaterThanOrEqual(0.0));
}
-
@Test
public void testMultithreading1() throws Exception {
@@ -822,7 +781,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
-
@Test
public void testMultithreading2() throws Exception {
@@ -885,11 +843,11 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(someInt, otherInt);
- } else
+ } else {
log.warn("Not a NativeRandom object received, skipping test");
+ }
}
-
@Test
public void testStepOver4() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000);
@@ -903,7 +861,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
-
@Test
public void testSignatures1() {
@@ -915,7 +872,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
-
@Test
public void testChoice1() {
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
@@ -926,7 +882,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, sampled);
}
-
@Test
public void testChoice2() {
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
@@ -937,8 +892,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, sampled);
}
-
-
@Ignore
@Test
public void testDeallocation1() throws Exception {
@@ -952,7 +905,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
-
@Test
public void someTest() {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@@ -1353,9 +1305,6 @@ public class RandomTests extends BaseNd4jTest {
log.info("Java mean: {}; Native mean: {}", mean, z01.meanNumber().doubleValue());
assertEquals(mean, z01.meanNumber().doubleValue(), 1e-1);
-
-
-
}
@Test
@@ -1364,44 +1313,32 @@ public class RandomTests extends BaseNd4jTest {
INDArray exp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
assertEquals(exp, res);
-
}
@Test
public void testOrthogonalDistribution1() {
val dist = new OrthogonalDistribution(1.0);
-
val array = dist.sample(new int[] {6, 9});
-
-// log.info("Array: {}", array);
}
@Test
public void testOrthogonalDistribution2() {
val dist = new OrthogonalDistribution(1.0);
-
val array = dist.sample(new int[] {9, 6});
-
-// log.info("Array: {}", array);
}
@Test
public void testOrthogonalDistribution3() {
val dist = new OrthogonalDistribution(1.0);
-
val array = dist.sample(new int[] {9, 9});
-
-// log.info("Array: {}", array);
}
@Test
public void reproducabilityTest(){
-
int numBatches = 1;
- for( int t=0; t<10; t++ ) {
-// System.out.println(t);
+ for (int t = 0; t < 10; t++) {
numBatches = t;
List initial = getList(numBatches);
@@ -1410,7 +1347,6 @@ public class RandomTests extends BaseNd4jTest {
List list = getList(numBatches);
assertEquals(initial, list);
}
-
}
}
@@ -1428,7 +1364,6 @@ public class RandomTests extends BaseNd4jTest {
Nd4j.getRandom().setSeed(12345);
INDArray arr = Nd4j.create(DataType.DOUBLE, 100);
Nd4j.exec(new BernoulliDistribution(arr, 0.5));
-// System.out.println(arr);
double sum = arr.sumNumber().doubleValue();
assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0);
}
@@ -1436,7 +1371,6 @@ public class RandomTests extends BaseNd4jTest {
private List getList(int numBatches){
Nd4j.getRandom().setSeed(12345);
List out = new ArrayList<>();
-// int numBatches = 32; //passes with 1 or 2
int channels = 3;
int imageHeight = 64;
int imageWidth = 64;
@@ -1446,7 +1380,6 @@ public class RandomTests extends BaseNd4jTest {
return out;
}
-
@Test
public void testRngRepeatabilityUniform(){
val nexp = Nd4j.create(DataType.FLOAT, 10);
@@ -1521,7 +1454,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(res[0], res1[0]);
}
-
@Test
public void testRandom() {
val r1 = new java.util.Random(119);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
index c87ab0956..fea1d8451 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RngValidationTests.java
@@ -16,12 +16,16 @@
package org.nd4j.linalg.rng;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.nd4j.OpValidationSuite;
import org.nd4j.common.base.Preconditions;
+import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@@ -33,19 +37,27 @@ import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
+import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
+import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
+import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
+import org.nd4j.linalg.api.ops.random.impl.Choice;
+import org.nd4j.linalg.api.ops.random.impl.DropOut;
+import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
+import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
+import org.nd4j.linalg.api.ops.random.impl.Linspace;
+import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
+import org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge;
+import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
+import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.common.util.ArrayUtil;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
@Slf4j
public class RngValidationTests extends BaseNd4jTest {
@@ -407,8 +419,6 @@ public class RngValidationTests extends BaseNd4jTest {
double alpha = alphaDropoutA(tc.prop("p"));
double beta = alphaDropoutB(tc.prop("p"));
return new AlphaDropOut(Nd4j.ones(tc.getDataType(), tc.shape), tc.arr(), tc.prop("p"), alpha, ALPHA_PRIME, beta);
-
-
case "distributionuniform":
INDArray shape = tc.getShape().length == 0 ? Nd4j.empty(DataType.LONG) : Nd4j.create(ArrayUtil.toDouble(tc.shape)).castTo(DataType.LONG);
return new DistributionUniform(shape, tc.arr(), tc.prop("min"), tc.prop("max"));
@@ -437,7 +447,6 @@ public class RngValidationTests extends BaseNd4jTest {
return Math.abs(x-y) / (Math.abs(x) + Math.abs(y));
}
-
public static final double DEFAULT_ALPHA = 1.6732632423543772;
public static final double DEFAULT_LAMBDA = 1.0507009873554804;
public static final double ALPHA_PRIME = -DEFAULT_LAMBDA * DEFAULT_ALPHA;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java
index 73f74ca05..1d1340dd5 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java
@@ -29,6 +29,12 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
+import org.nd4j.linalg.api.memory.enums.LearningPolicy;
+import org.nd4j.linalg.api.memory.enums.LocationPolicy;
+import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
+import org.nd4j.linalg.api.memory.enums.ResetPolicy;
+import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java
index da5c4f0f9..c8755ea89 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/DebugModeTests.java
@@ -26,6 +26,11 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
+import org.nd4j.linalg.api.memory.enums.DebugMode;
+import org.nd4j.linalg.api.memory.enums.LearningPolicy;
+import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
+import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
index b76c988ba..8345610c2 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java
@@ -16,6 +16,11 @@
package org.nd4j.linalg.workspace;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
+
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.After;
@@ -25,19 +30,22 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
+import org.nd4j.linalg.api.memory.enums.LearningPolicy;
+import org.nd4j.linalg.api.memory.enums.LocationPolicy;
+import org.nd4j.linalg.api.memory.enums.ResetPolicy;
+import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
-import static org.junit.Assert.*;
-
/**
* @author raver119@gmail.com
*/
@@ -60,9 +68,15 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
@Test
public void testVariableTimeSeries1() {
- WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
- .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.EXTERNAL)
- .policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build();
+ WorkspaceConfiguration configuration = WorkspaceConfiguration
+ .builder()
+ .initialSize(0)
+ .overallocationLimit(3.0)
+ .policyAllocation(AllocationPolicy.OVERALLOCATE)
+ .policySpill(SpillPolicy.EXTERNAL)
+ .policyLearning(LearningPolicy.FIRST_LOOP)
+ .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
+ .build();
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
Nd4j.create(500);
@@ -70,7 +84,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
}
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
-// workspace.enableDebug(true);
assertEquals(0, workspace.getStepNumber());
@@ -125,7 +138,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
log.info("Workspace state after first block: ---------------------------------------------------------");
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
-
log.info("--------------------------------------------------------------------------------------------");
// we just do huge loop now, with pinned stuff in it
@@ -144,7 +156,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertNotEquals(0, workspace.getNumberOfPinnedAllocations());
assertEquals(0, workspace.getNumberOfExternalAllocations());
-
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
for (int i = 0; i < 100; i++) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
@@ -158,12 +169,10 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(0, workspace.getNumberOfPinnedAllocations());
assertEquals(0, workspace.getNumberOfExternalAllocations());
-
log.info("Workspace state after second block: ---------------------------------------------------------");
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
}
-
@Test
public void testVariableTimeSeries2() {
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
@@ -179,8 +188,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
Nd4j.create(500);
}
-
-
assertEquals(0, workspace.getStepNumber());
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
@@ -189,7 +196,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(shiftedSize, workspace.getInitialBlockSize());
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
-
for (int i = 0; i < 100; i++) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
Nd4j.create(500);
@@ -206,7 +212,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(0, workspace.getSpilledSize());
assertEquals(0, workspace.getPinnedSize());
-
}
@Test
@@ -238,7 +243,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(exp, result);
}
-
@Test
public void testAlignment_1() {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
@@ -260,7 +264,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
}
}
-
@Test
public void testNoOpExecution_1() {
val configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0)
@@ -424,7 +427,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
Files.delete(tmpFile);
}
-
@Test
public void testMigrateToWorkspace(){
val src = Nd4j.createFromArray (1L,2L);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
index 6517a4731..2b65c67d1 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/WorkspaceProviderTests.java
@@ -27,6 +27,11 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
+import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
+import org.nd4j.linalg.api.memory.enums.LearningPolicy;
+import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
+import org.nd4j.linalg.api.memory.enums.ResetPolicy;
+import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;