FIX: ND4J tests (#9114)

Signed-off-by: hosuaby <alexei.klenin@gmail.com>
master
Alexei KLENIN 2020-10-26 15:17:17 -07:00 committed by GitHub
parent ca4aee16ec
commit 2e000c84ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 513 additions and 298 deletions

View File

@ -14,7 +14,6 @@ Please search for the latest version on search.maven.org.
Or use the versions displayed in:
https://github.com/eclipse/deeplearning4j-examples/blob/master/pom.xml
---
## Main Features
@ -47,6 +46,29 @@ To install ND4J, there are a couple of approaches, and more information can be f
#### Clone from the GitHub Repo
https://deeplearning4j.org/docs/latest/deeplearning4j-build-from-source
#### Build from sources
To build `ND4J` from sources launch from the present directory:
```shell script
$ mvn clean install -DskipTests=true
```
To run tests using CPU or CUDA backend run the following.
For CPU:
```shell script
$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cpu -P nd4j-tf-cpu
```
For CUDA:
```shell script
$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cuda -P nd4j-tf-gpu
```
## Contribute
1. Check for open issues, or open a new issue to start a discussion around a feature idea or a bug.

View File

@ -1,4 +1,4 @@
// Targeted by JavaCPP version 1.5.4-SNAPSHOT: DO NOT EDIT THIS FILE
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas;
@ -6,6 +6,10 @@ import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;
import static org.bytedeco.javacpp.presets.javacpp.*;
import static org.bytedeco.openblas.global.openblas_nolapack.*;
import static org.bytedeco.openblas.global.openblas.*;
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
static { Loader.load(); }

View File

@ -128,7 +128,6 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<!-- https://maven.apache.org/surefire/maven-surefire-plugin/examples/fork-options-and-parallel-execution.html -->
<dependency>
<groupId>com.github.stephenc.jcip</groupId>
<artifactId>jcip-annotations</artifactId>

View File

@ -17,12 +17,14 @@
package org.nd4j.autodiff;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.ImportClassMapping;
import org.nd4j.linalg.BaseNd4jTest;
@ -30,6 +32,36 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastGradientArgs;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
@ -55,6 +87,15 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
import org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater;
import org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater;
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
import org.nd4j.linalg.api.ops.persistence.SaveV2;
import org.nd4j.linalg.api.ops.util.PrintAffinity;
@ -66,13 +107,17 @@ import org.reflections.Reflections;
import java.io.File;
import java.lang.reflect.Modifier;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class TestOpMapping extends BaseNd4jTest {
Set<Class<? extends DifferentialFunction>> subTypes;
@ -303,9 +348,6 @@ public class TestOpMapping extends BaseNd4jTest {
s.add(PrintVariable.class);
s.add(PrintAffinity.class);
s.add(Assign.class);
}
@Test @Ignore

View File

@ -35,6 +35,15 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;

View File

@ -32,6 +32,18 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.Digamma;
import org.nd4j.linalg.api.ops.custom.DivideNoNan;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
import org.nd4j.linalg.api.ops.custom.Igamma;
import org.nd4j.linalg.api.ops.custom.Igammac;
import org.nd4j.linalg.api.ops.custom.Lgamma;
import org.nd4j.linalg.api.ops.custom.Lu;
import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
import org.nd4j.linalg.api.ops.custom.Polygamma;
import org.nd4j.linalg.api.ops.custom.Roll;
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;

View File

@ -24,6 +24,17 @@ import org.nd4j.autodiff.validation.OpTestCase;
import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;

View File

@ -38,7 +38,22 @@ import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;

View File

@ -35,6 +35,13 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.Tri;
import org.nd4j.linalg.api.ops.custom.Triu;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
import org.nd4j.linalg.api.ops.impl.shape.Permute;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;

View File

@ -48,8 +48,26 @@ import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup;
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
@ -2059,9 +2077,6 @@ public class TransformOpValidation extends BaseOpValidation {
);
assertNull(err);
}
@Test
@ -2085,7 +2100,6 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Test
public void testEmbeddingLookup() {
@ -2243,11 +2257,5 @@ public class TransformOpValidation extends BaseOpValidation {
.gradientCheck(true));
assertNull(err);
}
}
}
}

View File

@ -22,6 +22,14 @@ import static org.junit.Assert.fail;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.factory.Nd4jBackend;
public class ConvConfigTests extends BaseNd4jTest {
@ -487,8 +495,6 @@ public class ConvConfigTests extends BaseNd4jTest {
}
}
@Test
public void testConv1D(){
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();

View File

@ -22,6 +22,11 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -30,6 +35,17 @@ import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;

View File

@ -18,6 +18,11 @@ package org.nd4j.autodiff.samediff;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

View File

@ -45,6 +45,12 @@ import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.enums.WeightsFormat;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.Activation;
@ -3485,42 +3491,42 @@ public class SameDiffTests extends BaseNd4jTest {
}
@Test
public void testConcatVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
assertEquals(concatArray, map.get("concat"));
public void testConcatVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
assertEquals(concatArray, map.get("concat"));
}
}
@Test
public void testSliceVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr, input);
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
assertEquals(map.get("input"), map.get("concat"));
}
@Test
public void testSliceVariableGrad() {
SameDiff sd = SameDiff.create();
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
INDArray inputArr = Nd4j.rand(3,4);
INDArray labelArr = Nd4j.rand(3,4);
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
SDVariable c = sd.concat("concat", 1, a, b);
SDVariable loss = sd.math().pow(c.sub(label), 2);
sd.setLossVariables(loss);
sd.associateArrayWithVariable(labelArr, label);
sd.associateArrayWithVariable(inputArr, input);
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
assertEquals(map.get("input"), map.get("concat"));
}
@Test
public void testTrainingConfigJson(){

View File

@ -17,6 +17,13 @@
package org.nd4j.autodiff.samediff.listeners;
import org.junit.Test;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
@ -351,7 +358,7 @@ public class ListenerTest extends BaseNd4jTest {
}
@Override
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
public void iterationDone(final SameDiff sd, final At at, final MultiDataSet dataSet, final Loss loss) {
iterationDoneCount++;
}

View File

@ -28,6 +28,13 @@ import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.UIAddName;
import org.nd4j.graph.UIEvent;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIInfoType;
import org.nd4j.graph.UIOp;
import org.nd4j.graph.UIVariable;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
@ -150,7 +157,6 @@ public class FileReadWriteTests extends BaseNd4jTest {
assertEquals(UIInfoType.START_EVENTS, read.getData().get(1).getFirst().infoType());
//Append a number of events
w.registerEventName("accuracy");
for( int iter=0; iter<3; iter++) {

View File

@ -1,6 +1,12 @@
package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
@ -107,7 +113,6 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
assertTrue(t.getMessage(), t.getMessage().contains("no data"));
}
}
}
@Test

View File

@ -17,6 +17,12 @@
package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
@ -100,8 +106,6 @@ public class EvalJsonTest extends BaseNd4jTest {
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
for (IEvaluation e : arr) {
String json = e.toJson();
if (print) {

View File

@ -22,6 +22,11 @@ import org.junit.Test;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.imports.tensorflow.TFImportOverride;

View File

@ -35,7 +35,7 @@ import java.util.ServiceLoader;
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
//the system property for what backends should run
public final static String BACKENDS_TO_LOAD = "backends";
private static List<Nd4jBackend> BACKENDS;
private static List<Nd4jBackend> BACKENDS = new ArrayList<>();
static {
ServiceLoader<Nd4jBackend> loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
for (Nd4jBackend backend : loadedBackends) {

View File

@ -16,6 +16,13 @@
package org.nd4j.linalg;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
@ -23,10 +30,18 @@ import lombok.var;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.commons.math3.util.FastMath;
import org.junit.*;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.MathUtils;
import org.nd4j.enums.WeightsFormat;
import org.nd4j.imports.TFGraphs.NodeReader;
import org.nd4j.linalg.api.blas.Level1;
@ -47,6 +62,14 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
@ -64,6 +87,11 @@ import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
@ -73,8 +101,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
@ -94,20 +122,28 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.MathUtils;
import java.io.*;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
/**
* NDArrayTests
@ -148,8 +184,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Nd4j.setDataType(initialType);
}
@Test
public void testArangeNegative() {
INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE);
@ -241,9 +275,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray inDup = in.dup();
// System.out.println(in);
// System.out.println(inDup);
assertEquals(arr, in); //Passes: Original array "in" is OK, but array "inDup" is not!?
assertEquals(in, inDup); //Fails
}
@ -310,7 +341,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(assertion,test);
}
@Test
public void testAudoBroadcastAddMatrix() {
INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2);
@ -336,7 +366,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
@Test
public void testTensorAlongDimension() {
val shape = new long[] {4, 5, 7};
@ -538,7 +567,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
@Test
public void testGetColumns() {
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
@ -2719,7 +2747,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, zOutCF); //fails
}
@Test
public void testBroadcastDiv() {
INDArray num = Nd4j.create(new double[] {1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 1.00, 1.00, 1.00, 1.00,
@ -2753,7 +2780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testBroadcastMult() {
INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00,
@ -2796,7 +2822,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(expected, actual);
}
@Test
public void testDimension() {
INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
@ -4595,8 +4620,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
-1.25485503673});
INDArray reduced = Nd4j.getExecutioner().exec(new CosineDistance(haystack, needle, 1));
// log.info("Reduced: {}", reduced);
INDArray exp = Nd4j.create(new double[] {0.577452, 0.0, 1.80182});
assertEquals(exp, reduced);
@ -4606,9 +4629,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new CosineDistance(row, needle)).z().getDouble(0);
assertEquals("Failed at " + i, reduced.getDouble(i), res, 1e-5);
}
//cosinedistance([-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
//cosinedistance([.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
}
@Test
@ -4677,8 +4697,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
// log.info("Euclidean: {} vs {} is {}", x, needle, res);
}
}
@ -4698,8 +4716,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
// log.info("Euclidean: {} vs {} is {}", x, needle, res);
}
}
@ -4720,8 +4736,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(x, needle)).getFinalResult()
.doubleValue();
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
// log.info("Cosine: {} vs {} is {}", x, needle, res);
}
}
@ -4755,7 +4769,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAtan2_1() {
INDArray x = Nd4j.create(10).assign(-1.0);
@ -4767,7 +4780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, z);
}
@Test
public void testAtan2_2() {
INDArray x = Nd4j.create(10).assign(1.0);
@ -4779,7 +4791,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, z);
}
@Test
public void testJaccardDistance1() {
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0});
@ -4790,7 +4801,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(0.75, val, 1e-5);
}
@Test
public void testJaccardDistance2() {
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 1});
@ -4811,7 +4821,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(2.0 / 6, val, 1e-5);
}
@Test
public void testHammingDistance2() {
INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0});
@ -4822,7 +4831,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(3.0 / 6, val, 1e-5);
}
@Test
public void testHammingDistance3() {
INDArray x = Nd4j.create(DataType.DOUBLE, 10, 6);
@ -4831,7 +4839,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
x.getRow(r).putScalar(p, 1);
}
// log.info("X: {}", x);
INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0});
INDArray res = Nd4j.getExecutioner().exec(new HammingDistance(x, y, 1));
@ -4846,7 +4853,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAllDistances1() {
INDArray initialX = Nd4j.create(5, 10);
@ -4879,7 +4885,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAllDistances2() {
INDArray initialX = Nd4j.create(5, 10);
@ -4940,7 +4945,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAllDistances3_Large() {
INDArray initialX = Nd4j.create(5, 2000);
@ -4968,13 +4972,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
double res = result.getDouble(x, y);
double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup());
//log.info("Expected [{}, {}]: {}",x, y, exp);
assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001);
}
}
}
@Test
public void testAllDistances3_Large_Columns() {
INDArray initialX = Nd4j.create(2000, 5);
@ -5005,7 +5007,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAllDistances4_Large_Columns() {
INDArray initialX = Nd4j.create(2000, 5);
@ -5095,8 +5096,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testAllDistances3() {
Nd4j.getRandom().setSeed(123);
@ -5122,7 +5121,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testStridedTransforms1() {
//output: Rank: 2,Offset: 0
@ -5176,7 +5174,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testEntropy3() {
INDArray x = Nd4j.rand(1, 100);
@ -5197,7 +5194,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, res, 1e-5);
}
protected double getShannonEntropy(double[] array) {
double ret = 0;
for (double x : array) {
@ -5207,12 +5203,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
return -ret;
}
protected double getLogEntropy(double[] array) {
return Math.log(MathUtils.entropy(array));
}
@Test
public void testReverse1() {
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
@ -5228,8 +5222,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
// log.info("Array shapeInfo: {}", array.shapeInfoJava());
INDArray rev = Nd4j.reverse(array);
assertEquals(exp, rev);
@ -5278,7 +5270,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertTrue(rev == array);
}
@Test
public void testNativeSortView1() {
INDArray matrix = Nd4j.create(10, 10);
@ -5291,9 +5282,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Nd4j.sort(matrix.getColumn(0), true);
// log.info("Matrix: {}", matrix);
assertEquals(exp, matrix.getColumn(0));
}
@ -5384,9 +5372,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
Transforms.reverse(array, false);
// log.info("Reversed shapeInfo: {}", array.shapeInfoJava());
// log.info("Reversed: {}", array);
Transforms.reverse(array, false);
val jexp = exp.data().asInt();
@ -5401,9 +5386,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
val exp = array.dup(array.ordering());
val reversed = Transforms.reverse(array, true);
// log.info("Reversed: {}", reversed);
val rereversed = Transforms.reverse(reversed, true);
val jexp = exp.data().asInt();
@ -5445,8 +5427,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
INDArray exp = array.dup();
Transforms.reverse(array, false);
// log.info("Reverse: {}", array);
long time1 = System.currentTimeMillis();
INDArray res = Nd4j.sort(array, true);
@ -5465,7 +5445,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertNotEquals(exp1, dps);
for (int r = 0; r < array.rows(); r++) {
array.getRow(r).assign(dps);
}
@ -5485,7 +5464,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
protected boolean checkIfUnique(INDArray array, int iteration) {
var jarray = array.data().asInt();
var set = new HashSet<Integer>();
@ -5698,7 +5676,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testBroadcastAMax() {
INDArray matrix = Nd4j.create(5, 5);
@ -5715,7 +5692,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
}
}
@Test
public void testBroadcastAMin() {
INDArray matrix = Nd4j.create(5, 5);
@ -5767,7 +5743,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(exp, res);
}
@Test
public void testRDiv1() {
val argX = Nd4j.create(3).assign(2.0);
@ -5789,7 +5764,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(arrayC, arrayF);
}
@Test
public void testMatchTransform() {
val array = Nd4j.create(new double[] {1, 1, 1, 0, 1, 1},'c');
@ -5838,10 +5812,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
val a = Nd4j.linspace(1, x * A1 * A2, x * A1 * A2, DataType.DOUBLE).reshape(x, A1, A2);
val b = Nd4j.linspace(1, x * B1 * B2, x * B1 * B2, DataType.DOUBLE).reshape(x, B1, B2);
//
//log.info("C shape: {}", Arrays.toString(c.shapeInfoDataBuffer().asInt()));
}
@Test

View File

@ -21,6 +21,21 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationGELU;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationHardTanH;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationRReLU;
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -73,7 +88,6 @@ public class TestActivation extends BaseNd4jTest {
double[] dIn = in.data().asDouble();
for( int i=0; i<max.length; i++ ){
// System.out.println("i = " + i);
ActivationReLU r = new ActivationReLU(max[i], threshold[i], negativeSlope[i]);
INDArray out = r.getActivation(in.dup(), true);
double[] exp = new double[dIn.length];
@ -145,7 +159,6 @@ public class TestActivation extends BaseNd4jTest {
for (int i = 0; i < activations.length; i++) {
String asJson = mapper.writeValueAsString(activations[i]);
// System.out.println(asJson);
JsonNode node = mapper.readTree(asJson);

View File

@ -26,6 +26,13 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.IntervalIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndexAll;
import org.nd4j.linalg.indexing.NewAxis;
import org.nd4j.linalg.indexing.PointIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.common.util.ArrayUtil;
@ -75,8 +82,6 @@ public class IndexingTestsC extends BaseNd4jTest {
final INDArray aBad = col.broadcast(2, 2);
final INDArray aGood = col.dup().broadcast(2, 2);
// System.out.println(aBad);
// System.out.println(aGood);
assertTrue(Transforms.abs(aGood.sub(aBad).div(aGood)).maxNumber().doubleValue() < 0.01);
}
@ -446,12 +451,6 @@ public class IndexingTestsC extends BaseNd4jTest {
msg = "Test case: rank = " + rank + ", order = " + order + ", inShape = " + Arrays.toString(inShape) +
", outShape = " + Arrays.toString(expShape) +
", indexes = " + Arrays.toString(indexes) + ", newAxisTest=" + newAxisTestCase;
// System.out.println(msg);
// System.out.println(arr);
// System.out.println(sub);
// System.out.println();
NdIndexIterator posIter = new NdIndexIterator(expShape);
while (posIter.hasNext()) {
@ -467,7 +466,6 @@ public class IndexingTestsC extends BaseNd4jTest {
}
}
// System.out.println("TOTAL TEST CASES: " + totalTestCaseCount);
assertTrue(String.valueOf(totalTestCaseCount), totalTestCaseCount > 5000);
}
@ -556,14 +554,8 @@ public class IndexingTestsC extends BaseNd4jTest {
char order = 'c';
INDArray arr = Nd4j.linspace(DataType.FLOAT, 1, prod, prod).reshape('c', inShape).dup(order);
INDArray sub = arr.get(indexes);
// System.out.println(Arrays.toString(indexes));
// System.out.println(arr);
// System.out.println();
// System.out.println(sub);
}
@Override
public char ordering() {
return 'c';

View File

@ -1,8 +1,13 @@
package org.nd4j.linalg.convolution;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -10,12 +15,13 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.common.resources.Resources;
import java.io.File;
import java.util.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class DeconvTests extends BaseNd4jTest {
@ -33,10 +39,10 @@ public class DeconvTests extends BaseNd4jTest {
@Test
public void compareKeras() throws Exception {
File f = testDir.newFolder();
Resources.copyDirectory("keras/deconv", f);
File newFolder = testDir.newFolder();
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
File[] files = f.listFiles();
File[] files = newFolder.listFiles();
Set<String> tests = new HashSet<>();
for(File file : files){
@ -64,10 +70,10 @@ public class DeconvTests extends BaseNd4jTest {
int d = Integer.parseInt(nums[5]);
boolean nchw = s.contains("nchw");
INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy"));
INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy"));
INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT);
INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy"));
INDArray w = Nd4j.readNpy(new File(newFolder, s + "_W.npy"));
INDArray b = Nd4j.readNpy(new File(newFolder, s + "_b.npy"));
INDArray in = Nd4j.readNpy(new File(newFolder, s + "_in.npy")).castTo(DataType.FLOAT);
INDArray expOut = Nd4j.readNpy(new File(newFolder, s + "_out.npy"));
CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(in, w, b)

View File

@ -26,6 +26,37 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
import org.nd4j.linalg.api.ops.custom.AdjustHue;
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
import org.nd4j.linalg.api.ops.custom.BetaInc;
import org.nd4j.linalg.api.ops.custom.BitCast;
import org.nd4j.linalg.api.ops.custom.CompareAndBitpack;
import org.nd4j.linalg.api.ops.custom.DivideNoNan;
import org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes;
import org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel;
import org.nd4j.linalg.api.ops.custom.Flatten;
import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
import org.nd4j.linalg.api.ops.custom.HsvToRgb;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.api.ops.custom.Lgamma;
import org.nd4j.linalg.api.ops.custom.LinearSolve;
import org.nd4j.linalg.api.ops.custom.Logdet;
import org.nd4j.linalg.api.ops.custom.Lstsq;
import org.nd4j.linalg.api.ops.custom.Lu;
import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
import org.nd4j.linalg.api.ops.custom.Polygamma;
import org.nd4j.linalg.api.ops.custom.RandomCrop;
import org.nd4j.linalg.api.ops.custom.RgbToGrayscale;
import org.nd4j.linalg.api.ops.custom.RgbToHsv;
import org.nd4j.linalg.api.ops.custom.RgbToYiq;
import org.nd4j.linalg.api.ops.custom.RgbToYuv;
import org.nd4j.linalg.api.ops.custom.Roll;
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
import org.nd4j.linalg.api.ops.custom.ToggleBits;
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
import org.nd4j.linalg.api.ops.custom.YiqToRgb;
import org.nd4j.linalg.api.ops.custom.YuvToRgb;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
@ -373,7 +404,6 @@ public class CustomOpsTests extends BaseNd4jTest {
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
Nd4j.getExecutioner().exec(op);
// log.info("Matrix: {}", matrix);
assertEquals(exp0, matrix.getRow(0));
assertEquals(exp1, matrix.getRow(1));
assertEquals(exp0, matrix.getRow(2));
@ -1384,8 +1414,6 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
val c = Conditions.equals(0.0);
// System.out.println("Y:\n" + y);
INDArray z = x.match(y, c);
INDArray exp = Nd4j.createFromArray(new boolean[][]{
{false, false, false},
@ -1396,7 +1424,6 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(exp, z);
}
@Test
public void testCreateOp_1() {
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
@ -1862,11 +1889,9 @@ public class CustomOpsTests extends BaseNd4jTest {
System.out.println("in: " + in.shapeInfoToString());
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
INDArray out2 = out.like();
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
.addInputs(in)
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1

View File

@ -22,6 +22,14 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.CustomSerializerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
@ -65,7 +73,6 @@ public class NormalizerSerializerTest extends BaseNd4jTest {
ImagePreProcessingScaler restored = SUT.restore(tmpFile);
assertEquals(imagePreProcessingScaler,restored);
}
@Test

View File

@ -25,6 +25,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;

View File

@ -24,6 +24,12 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import static org.junit.Assert.assertEquals;
@ -53,7 +59,6 @@ public class UpdaterTest extends BaseNd4jTest {
int rows = 10;
int cols = 2;
NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
INDArray W = Nd4j.zeros(rows, cols);
@ -68,13 +73,11 @@ public class UpdaterTest extends BaseNd4jTest {
}
}
@Test
public void testAdaGrad() {
int rows = 10;
int cols = 2;
AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
INDArray W = Nd4j.zeros(rows, cols);

View File

@ -23,6 +23,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.AMSGrad;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.learning.config.AdaMax;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nadam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd;
import java.util.HashMap;
import java.util.Map;

View File

@ -21,6 +21,21 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossCosineProximity;
import org.nd4j.linalg.lossfunctions.impl.LossHinge;
import org.nd4j.linalg.lossfunctions.impl.LossKLD;
import org.nd4j.linalg.lossfunctions.impl.LossL1;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
import org.nd4j.linalg.lossfunctions.impl.LossMultiLabel;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.linalg.lossfunctions.impl.LossPoisson;
import org.nd4j.linalg.lossfunctions.impl.LossSquaredHinge;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;

View File

@ -28,6 +28,16 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossL1;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertTrue;

View File

@ -24,6 +24,13 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ops.BaseBroadcastOp;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.factory.Nd4j;

View File

@ -26,6 +26,12 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
@ -84,7 +90,6 @@ public class DerivativeTests extends BaseNd4jTest {
}
}
@Test
public void testRectifiedLinearDerivative() {
//ReLU:
@ -166,11 +171,7 @@ public class DerivativeTests extends BaseNd4jTest {
}
INDArray z = Transforms.hardSigmoid(xArr, true);
INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));
// System.out.println(xArr);
// System.out.println(z);
// System.out.println(zPrime);
INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));;
for (int i = 0; i < expHSOut.length; i++) {
double relErrorHS =

View File

@ -32,6 +32,21 @@ import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomGamma;
import org.nd4j.linalg.api.ops.random.custom.RandomPoisson;
import org.nd4j.linalg.api.ops.random.custom.RandomShuffle;
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.Linspace;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.rng.distribution.Distribution;
@ -78,7 +93,6 @@ public class RandomTests extends BaseNd4jTest {
@Test
public void testCrossBackendEquality1() {
int[] shape = {12};
double mean = 0;
double standardDeviation = 1.0;
@ -87,8 +101,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray arr = Nd4j.getExecutioner().exec(new GaussianDistribution(
Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom());
// log.info("arr: {}", arr.data().asDouble());
assertEquals(exp, arr);
}
@ -105,8 +117,6 @@ public class RandomTests extends BaseNd4jTest {
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
Nd4j.getExecutioner().exec(distribution2, random2);
// System.out.println("Data: " + z1);
// System.out.println("Data: " + z2);
for (int e = 0; e < z1.length(); e++) {
double val = z1.getDouble(e);
assertTrue(val >= 1.0 && val <= 2.0);
@ -135,8 +145,6 @@ public class RandomTests extends BaseNd4jTest {
log.info("States cpu: {}/{}", random1.rootState(), random1.nodeState());
// System.out.println("Data: " + z1);
// System.out.println("Data: " + z2);
for (int e = 0; e < z1.length(); e++) {
double val = z1.getDouble(e);
assertTrue(val >= 1.0 && val <= 2.0);
@ -156,9 +164,6 @@ public class RandomTests extends BaseNd4jTest {
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
Nd4j.getExecutioner().exec(distribution2, random1);
// System.out.println("Data: " + z1);
// System.out.println("Data: " + z2);
assertNotEquals(z1, z2);
}
@ -174,7 +179,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.randn('c', new int[] {1, 1000});
assertEquals("Failed on iteration " + i, z1, z2);
}
}
@ -190,7 +194,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.rand('c', new int[] {1, 1000});
assertEquals("Failed on iteration " + i, z1, z2);
}
}
@ -206,7 +209,6 @@ public class RandomTests extends BaseNd4jTest {
INDArray z2 = Nd4j.getExecutioner().exec(new BinomialDistribution(Nd4j.createUninitialized(1000), 10, 0.2));
assertEquals("Failed on iteration " + i, z1, z2);
}
}
@ -222,8 +224,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z1, z2);
}
@Test
public void testDropoutInverted1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -318,7 +318,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z1, z2);
}
@Test
public void testGaussianDistribution2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -403,8 +402,6 @@ public class RandomTests extends BaseNd4jTest {
Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
Nd4j.sort(z1, true);
// System.out.println("Data for Anderson-Darling: " + z1);
for (int i = 0; i < n; i++) {
Double res = nd.cumulativeProbability(z1.getDouble(i));
@ -432,9 +429,6 @@ public class RandomTests extends BaseNd4jTest {
public void testStepOver1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
// log.info("1: ----------------");
INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0));
assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
@ -442,36 +436,15 @@ public class RandomTests extends BaseNd4jTest {
random1.setSeed(119);
// log.info("2: ----------------");
INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000);
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
Nd4j.getExecutioner().exec(op1, random1);
// log.info("2: ----------------");
//log.info("End: [{}, {}, {}, {}]", z1.getFloat(29000000), z1.getFloat(29000001), z1.getFloat(29000002), z1.getFloat(29000003));
//log.info("Sum: {}", z1.sumNumber().doubleValue());
// log.info("Sum2: {}", z2.sumNumber().doubleValue());
INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()));
// log.info("NaNs: {}", match);
assertEquals(0.0f, match.getFloat(0), 0.01f);
/*
for (int i = 0; i < z1.length(); i++) {
if (Double.isNaN(z1.getDouble(i)))
throw new IllegalStateException("NaN value found at " + i);
if (Double.isInfinite(z1.getDouble(i)))
throw new IllegalStateException("Infinite value found at " + i);
}
*/
assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
}
@ -480,7 +453,6 @@ public class RandomTests extends BaseNd4jTest {
public void testSum_119() {
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
val sum = z2.sumNumber().doubleValue();
// log.info("Sum2: {}", sum);
assertEquals(0.0, sum, 1e-5);
}
@ -493,7 +465,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
}
@Test
public void testSetSeed1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -533,8 +504,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(z02, z12);
}
@Test
public void testJavaSide1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -553,8 +522,6 @@ public class RandomTests extends BaseNd4jTest {
assertArrayEquals(array1, array2, 1e-5f);
}
@Test
public void testJavaSide2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -574,7 +541,6 @@ public class RandomTests extends BaseNd4jTest {
assertArrayEquals(array1, array2);
}
@Test
public void testJavaSide3() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -657,8 +623,6 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(0, sum);
}
@Test
public void testBernoulliDistribution1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -677,11 +641,8 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(z1Dup, z1);
assertEquals(z1, z2);
}
@Test
public void testBernoulliDistribution2() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -690,7 +651,8 @@ public class RandomTests extends BaseNd4jTest {
INDArray z1 = Nd4j.zeros(20);
INDArray z2 = Nd4j.zeros(20);
INDArray z1Dup = Nd4j.zeros(20);
INDArray exp = Nd4j.create(new double[] {0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000});
INDArray exp = Nd4j.create(new double[]{ 0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000,
1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000 });
BernoulliDistribution op1 = new BernoulliDistribution(z1, 0.50);
BernoulliDistribution op2 = new BernoulliDistribution(z2, 0.50);
@ -705,7 +667,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, z1);
}
@Test
public void testBernoulliDistribution3() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -716,7 +677,7 @@ public class RandomTests extends BaseNd4jTest {
INDArray z1 = Nd4j.zeros(10);
INDArray z2 = Nd4j.zeros(10);
INDArray z1Dup = Nd4j.zeros(10);
INDArray exp = Nd4j.create(new double[] {1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0});
INDArray exp = Nd4j.create(new double[]{ 1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0 });
BernoulliDistribution op1 = new BernoulliDistribution(z1, prob);
BernoulliDistribution op2 = new BernoulliDistribution(z2, prob);
@ -731,7 +692,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, z1);
}
@Test
public void testBinomialDistribution1() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
@ -780,7 +740,6 @@ public class RandomTests extends BaseNd4jTest {
BooleanIndexing.and(z1, Conditions.greaterThanOrEqual(0.0));
}
@Test
public void testMultithreading1() throws Exception {
@ -822,7 +781,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
@Test
public void testMultithreading2() throws Exception {
@ -885,11 +843,11 @@ public class RandomTests extends BaseNd4jTest {
assertNotEquals(someInt, otherInt);
} else
} else {
log.warn("Not a NativeRandom object received, skipping test");
}
}
@Test
public void testStepOver4() {
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000);
@ -903,7 +861,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
@Test
public void testSignatures1() {
@ -915,7 +872,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
@Test
public void testChoice1() {
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
@ -926,7 +882,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, sampled);
}
@Test
public void testChoice2() {
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
@ -937,8 +892,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(exp, sampled);
}
@Ignore
@Test
public void testDeallocation1() throws Exception {
@ -952,7 +905,6 @@ public class RandomTests extends BaseNd4jTest {
}
}
@Test
public void someTest() {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
@ -1353,9 +1305,6 @@ public class RandomTests extends BaseNd4jTest {
log.info("Java mean: {}; Native mean: {}", mean, z01.meanNumber().doubleValue());
assertEquals(mean, z01.meanNumber().doubleValue(), 1e-1);
}
@Test
@ -1364,44 +1313,32 @@ public class RandomTests extends BaseNd4jTest {
INDArray exp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
assertEquals(exp, res);
}
@Test
public void testOrthogonalDistribution1() {
val dist = new OrthogonalDistribution(1.0);
val array = dist.sample(new int[] {6, 9});
// log.info("Array: {}", array);
}
@Test
public void testOrthogonalDistribution2() {
val dist = new OrthogonalDistribution(1.0);
val array = dist.sample(new int[] {9, 6});
// log.info("Array: {}", array);
}
@Test
public void testOrthogonalDistribution3() {
val dist = new OrthogonalDistribution(1.0);
val array = dist.sample(new int[] {9, 9});
// log.info("Array: {}", array);
}
@Test
public void reproducabilityTest(){
int numBatches = 1;
for( int t=0; t<10; t++ ) {
// System.out.println(t);
for (int t = 0; t < 10; t++) {
numBatches = t;
List<INDArray> initial = getList(numBatches);
@ -1410,7 +1347,6 @@ public class RandomTests extends BaseNd4jTest {
List<INDArray> list = getList(numBatches);
assertEquals(initial, list);
}
}
}
@ -1428,7 +1364,6 @@ public class RandomTests extends BaseNd4jTest {
Nd4j.getRandom().setSeed(12345);
INDArray arr = Nd4j.create(DataType.DOUBLE, 100);
Nd4j.exec(new BernoulliDistribution(arr, 0.5));
// System.out.println(arr);
double sum = arr.sumNumber().doubleValue();
assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0);
}
@ -1436,7 +1371,6 @@ public class RandomTests extends BaseNd4jTest {
private List<INDArray> getList(int numBatches){
Nd4j.getRandom().setSeed(12345);
List<INDArray> out = new ArrayList<>();
// int numBatches = 32; //passes with 1 or 2
int channels = 3;
int imageHeight = 64;
int imageWidth = 64;
@ -1446,7 +1380,6 @@ public class RandomTests extends BaseNd4jTest {
return out;
}
@Test
public void testRngRepeatabilityUniform(){
val nexp = Nd4j.create(DataType.FLOAT, 10);
@ -1521,7 +1454,6 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(res[0], res1[0]);
}
@Test
public void testRandom() {
val r1 = new java.util.Random(119);

View File

@ -16,12 +16,16 @@
package org.nd4j.linalg.rng;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.junit.Test;
import org.nd4j.OpValidationSuite;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -33,19 +37,27 @@ import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.Choice;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.Linspace;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.common.util.ArrayUtil;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@Slf4j
public class RngValidationTests extends BaseNd4jTest {
@ -407,8 +419,6 @@ public class RngValidationTests extends BaseNd4jTest {
double alpha = alphaDropoutA(tc.prop("p"));
double beta = alphaDropoutB(tc.prop("p"));
return new AlphaDropOut(Nd4j.ones(tc.getDataType(), tc.shape), tc.arr(), tc.prop("p"), alpha, ALPHA_PRIME, beta);
case "distributionuniform":
INDArray shape = tc.getShape().length == 0 ? Nd4j.empty(DataType.LONG) : Nd4j.create(ArrayUtil.toDouble(tc.shape)).castTo(DataType.LONG);
return new DistributionUniform(shape, tc.arr(), tc.prop("min"), tc.prop("max"));
@ -437,7 +447,6 @@ public class RngValidationTests extends BaseNd4jTest {
return Math.abs(x-y) / (Math.abs(x) + Math.abs(y));
}
public static final double DEFAULT_ALPHA = 1.6732632423543772;
public static final double DEFAULT_LAMBDA = 1.0507009873554804;
public static final double ALPHA_PRIME = -DEFAULT_LAMBDA * DEFAULT_ALPHA;

View File

@ -29,6 +29,12 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;

View File

@ -26,6 +26,11 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.DebugMode;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;

View File

@ -16,6 +16,11 @@
package org.nd4j.linalg.workspace;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.After;
@ -25,19 +30,22 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import static org.junit.Assert.*;
/**
* @author raver119@gmail.com
*/
@ -60,9 +68,15 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
@Test
public void testVariableTimeSeries1() {
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.EXTERNAL)
.policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build();
WorkspaceConfiguration configuration = WorkspaceConfiguration
.builder()
.initialSize(0)
.overallocationLimit(3.0)
.policyAllocation(AllocationPolicy.OVERALLOCATE)
.policySpill(SpillPolicy.EXTERNAL)
.policyLearning(LearningPolicy.FIRST_LOOP)
.policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
.build();
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
Nd4j.create(500);
@ -70,7 +84,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
}
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
// workspace.enableDebug(true);
assertEquals(0, workspace.getStepNumber());
@ -125,7 +138,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
log.info("Workspace state after first block: ---------------------------------------------------------");
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
log.info("--------------------------------------------------------------------------------------------");
// we just do huge loop now, with pinned stuff in it
@ -144,7 +156,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertNotEquals(0, workspace.getNumberOfPinnedAllocations());
assertEquals(0, workspace.getNumberOfExternalAllocations());
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
for (int i = 0; i < 100; i++) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
@ -158,12 +169,10 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(0, workspace.getNumberOfPinnedAllocations());
assertEquals(0, workspace.getNumberOfExternalAllocations());
log.info("Workspace state after second block: ---------------------------------------------------------");
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
}
@Test
public void testVariableTimeSeries2() {
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
@ -179,8 +188,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
Nd4j.create(500);
}
assertEquals(0, workspace.getStepNumber());
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
@ -189,7 +196,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(shiftedSize, workspace.getInitialBlockSize());
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
for (int i = 0; i < 100; i++) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
Nd4j.create(500);
@ -206,7 +212,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(0, workspace.getSpilledSize());
assertEquals(0, workspace.getPinnedSize());
}
@Test
@ -238,7 +243,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
assertEquals(exp, result);
}
@Test
public void testAlignment_1() {
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
@ -260,7 +264,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
}
}
@Test
public void testNoOpExecution_1() {
val configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0)
@ -424,7 +427,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
Files.delete(tmpFile);
}
@Test
public void testMigrateToWorkspace(){
val src = Nd4j.createFromArray (1L,2L);

View File

@ -27,6 +27,11 @@ import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;