parent
ca4aee16ec
commit
2e000c84ac
|
@ -14,7 +14,6 @@ Please search for the latest version on search.maven.org.
|
||||||
Or use the versions displayed in:
|
Or use the versions displayed in:
|
||||||
https://github.com/eclipse/deeplearning4j-examples/blob/master/pom.xml
|
https://github.com/eclipse/deeplearning4j-examples/blob/master/pom.xml
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
## Main Features
|
## Main Features
|
||||||
|
|
||||||
|
@ -47,6 +46,29 @@ To install ND4J, there are a couple of approaches, and more information can be f
|
||||||
#### Clone from the GitHub Repo
|
#### Clone from the GitHub Repo
|
||||||
|
|
||||||
https://deeplearning4j.org/docs/latest/deeplearning4j-build-from-source
|
https://deeplearning4j.org/docs/latest/deeplearning4j-build-from-source
|
||||||
|
|
||||||
|
#### Build from sources
|
||||||
|
|
||||||
|
To build `ND4J` from sources launch from the present directory:
|
||||||
|
|
||||||
|
```shell script
|
||||||
|
$ mvn clean install -DskipTests=true
|
||||||
|
```
|
||||||
|
|
||||||
|
To run tests using CPU or CUDA backend run the following.
|
||||||
|
|
||||||
|
For CPU:
|
||||||
|
|
||||||
|
```shell script
|
||||||
|
$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cpu -P nd4j-tf-cpu
|
||||||
|
```
|
||||||
|
|
||||||
|
For CUDA:
|
||||||
|
|
||||||
|
```shell script
|
||||||
|
$ mvn clean test -P testresources -P nd4j-testresources -P nd4j-tests-cuda -P nd4j-tf-gpu
|
||||||
|
```
|
||||||
|
|
||||||
## Contribute
|
## Contribute
|
||||||
|
|
||||||
1. Check for open issues, or open a new issue to start a discussion around a feature idea or a bug.
|
1. Check for open issues, or open a new issue to start a discussion around a feature idea or a bug.
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
// Targeted by JavaCPP version 1.5.4-SNAPSHOT: DO NOT EDIT THIS FILE
|
// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE
|
||||||
|
|
||||||
package org.nd4j.nativeblas;
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
@ -6,6 +6,10 @@ import java.nio.*;
|
||||||
import org.bytedeco.javacpp.*;
|
import org.bytedeco.javacpp.*;
|
||||||
import org.bytedeco.javacpp.annotation.*;
|
import org.bytedeco.javacpp.annotation.*;
|
||||||
|
|
||||||
|
import static org.bytedeco.javacpp.presets.javacpp.*;
|
||||||
|
import static org.bytedeco.openblas.global.openblas_nolapack.*;
|
||||||
|
import static org.bytedeco.openblas.global.openblas.*;
|
||||||
|
|
||||||
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
||||||
static { Loader.load(); }
|
static { Loader.load(); }
|
||||||
|
|
||||||
|
|
|
@ -128,7 +128,6 @@
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<!-- https://maven.apache.org/surefire/maven-surefire-plugin/examples/fork-options-and-parallel-execution.html -->
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.stephenc.jcip</groupId>
|
<groupId>com.github.stephenc.jcip</groupId>
|
||||||
<artifactId>jcip-annotations</artifactId>
|
<artifactId>jcip-annotations</artifactId>
|
||||||
|
|
|
@ -17,12 +17,14 @@
|
||||||
|
|
||||||
package org.nd4j.autodiff;
|
package org.nd4j.autodiff;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.imports.converters.ImportClassMapping;
|
import org.nd4j.imports.converters.ImportClassMapping;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
@ -30,6 +32,36 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.NoOp;
|
import org.nd4j.linalg.api.ops.NoOp;
|
||||||
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
|
import org.nd4j.linalg.api.ops.compat.CompatSparseToDense;
|
||||||
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
|
import org.nd4j.linalg.api.ops.compat.CompatStringSplit;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastGradientArgs;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
|
||||||
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
|
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
|
@ -55,6 +87,15 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.CopyOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.PowPairwise;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater;
|
||||||
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
|
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
|
||||||
import org.nd4j.linalg.api.ops.persistence.SaveV2;
|
import org.nd4j.linalg.api.ops.persistence.SaveV2;
|
||||||
import org.nd4j.linalg.api.ops.util.PrintAffinity;
|
import org.nd4j.linalg.api.ops.util.PrintAffinity;
|
||||||
|
@ -66,13 +107,17 @@ import org.reflections.Reflections;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.lang.reflect.Modifier;
|
import java.lang.reflect.Modifier;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.regex.Matcher;
|
import java.util.regex.Matcher;
|
||||||
import java.util.regex.Pattern;
|
import java.util.regex.Pattern;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.assertTrue;
|
|
||||||
|
|
||||||
public class TestOpMapping extends BaseNd4jTest {
|
public class TestOpMapping extends BaseNd4jTest {
|
||||||
|
|
||||||
Set<Class<? extends DifferentialFunction>> subTypes;
|
Set<Class<? extends DifferentialFunction>> subTypes;
|
||||||
|
@ -303,9 +348,6 @@ public class TestOpMapping extends BaseNd4jTest {
|
||||||
s.add(PrintVariable.class);
|
s.add(PrintVariable.class);
|
||||||
s.add(PrintAffinity.class);
|
s.add(PrintAffinity.class);
|
||||||
s.add(Assign.class);
|
s.add(Assign.class);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test @Ignore
|
@Test @Ignore
|
||||||
|
|
|
@ -35,6 +35,15 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
|
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
|
||||||
|
|
|
@ -32,6 +32,18 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Digamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.DivideNoNan;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Flatten;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Igamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Igammac;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Lgamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Lu;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Polygamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Roll;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
|
||||||
|
|
|
@ -24,6 +24,17 @@ import org.nd4j.autodiff.validation.OpTestCase;
|
||||||
import org.nd4j.autodiff.validation.OpValidation;
|
import org.nd4j.autodiff.validation.OpValidation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
|
|
@ -38,7 +38,22 @@ import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
|
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
|
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
|
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
|
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.Dot;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
|
||||||
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
|
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||||||
|
|
|
@ -35,6 +35,13 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.custom.Tri;
|
import org.nd4j.linalg.api.ops.custom.Tri;
|
||||||
import org.nd4j.linalg.api.ops.custom.Triu;
|
import org.nd4j.linalg.api.ops.custom.Triu;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.Permute;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
|
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
|
||||||
|
|
|
@ -48,8 +48,26 @@ import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
|
import org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex;
|
||||||
import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup;
|
import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm;
|
import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Max;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Min;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
|
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erf;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -2059,9 +2077,6 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
);
|
);
|
||||||
|
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -2085,7 +2100,6 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEmbeddingLookup() {
|
public void testEmbeddingLookup() {
|
||||||
|
|
||||||
|
@ -2243,11 +2257,5 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
.gradientCheck(true));
|
.gradientCheck(true));
|
||||||
assertNull(err);
|
assertNull(err);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,14 @@ import static org.junit.Assert.fail;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.PaddingMode;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
public class ConvConfigTests extends BaseNd4jTest {
|
public class ConvConfigTests extends BaseNd4jTest {
|
||||||
|
@ -487,8 +495,6 @@ public class ConvConfigTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1D(){
|
public void testConv1D(){
|
||||||
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
||||||
|
|
|
@ -22,6 +22,11 @@ import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
|
import org.nd4j.graph.FlatConfiguration;
|
||||||
|
import org.nd4j.graph.FlatGraph;
|
||||||
|
import org.nd4j.graph.FlatNode;
|
||||||
|
import org.nd4j.graph.FlatVariable;
|
||||||
|
import org.nd4j.graph.IntPair;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -30,6 +35,17 @@ import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.learning.GradientUpdater;
|
import org.nd4j.linalg.learning.GradientUpdater;
|
||||||
|
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaDelta;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaMax;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
import org.nd4j.linalg.learning.config.Nadam;
|
||||||
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
|
import org.nd4j.linalg.learning.config.NoOp;
|
||||||
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
|
@ -18,6 +18,11 @@ package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.OpPredicate;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraph;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
|
@ -45,6 +45,12 @@ import org.nd4j.autodiff.validation.OpValidation;
|
||||||
import org.nd4j.autodiff.validation.TestCase;
|
import org.nd4j.autodiff.validation.TestCase;
|
||||||
import org.nd4j.enums.WeightsFormat;
|
import org.nd4j.enums.WeightsFormat;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||||
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
import org.nd4j.evaluation.classification.ROCBinary;
|
||||||
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -3485,42 +3491,42 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConcatVariableGrad() {
|
public void testConcatVariableGrad() {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
||||||
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
|
SDVariable a = sd.var("a", DataType.FLOAT, 3, 2);
|
||||||
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
|
SDVariable b = sd.var("b", DataType.FLOAT, 3, 2);
|
||||||
INDArray inputArr = Nd4j.rand(3,4);
|
INDArray inputArr = Nd4j.rand(3,4);
|
||||||
INDArray labelArr = Nd4j.rand(3,4);
|
INDArray labelArr = Nd4j.rand(3,4);
|
||||||
SDVariable c = sd.concat("concat", 1, a, b);
|
SDVariable c = sd.concat("concat", 1, a, b);
|
||||||
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
||||||
sd.setLossVariables(loss);
|
sd.setLossVariables(loss);
|
||||||
sd.associateArrayWithVariable(labelArr, label);
|
sd.associateArrayWithVariable(labelArr, label);
|
||||||
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
|
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 2)), a);
|
||||||
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
|
sd.associateArrayWithVariable(inputArr.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)), b);
|
||||||
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
|
Map<String, INDArray> map = sd.calculateGradients(null, "a", "b", "concat");
|
||||||
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
|
INDArray concatArray = Nd4j.hstack(map.get("a"), map.get("b"));
|
||||||
assertEquals(concatArray, map.get("concat"));
|
assertEquals(concatArray, map.get("concat"));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSliceVariableGrad() {
|
public void testSliceVariableGrad() {
|
||||||
SameDiff sd = SameDiff.create();
|
SameDiff sd = SameDiff.create();
|
||||||
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
SDVariable label = sd.var("label", DataType.FLOAT, 3, 4);
|
||||||
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
|
SDVariable input = sd.var("input", DataType.FLOAT, 3, 4);
|
||||||
INDArray inputArr = Nd4j.rand(3,4);
|
INDArray inputArr = Nd4j.rand(3,4);
|
||||||
INDArray labelArr = Nd4j.rand(3,4);
|
INDArray labelArr = Nd4j.rand(3,4);
|
||||||
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
|
SDVariable a = input.get(SDIndex.all(), SDIndex.interval(0, 2));
|
||||||
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
|
SDVariable b = input.get(SDIndex.all(), SDIndex.interval(2, 4));
|
||||||
SDVariable c = sd.concat("concat", 1, a, b);
|
SDVariable c = sd.concat("concat", 1, a, b);
|
||||||
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
SDVariable loss = sd.math().pow(c.sub(label), 2);
|
||||||
sd.setLossVariables(loss);
|
sd.setLossVariables(loss);
|
||||||
sd.associateArrayWithVariable(labelArr, label);
|
sd.associateArrayWithVariable(labelArr, label);
|
||||||
sd.associateArrayWithVariable(inputArr, input);
|
sd.associateArrayWithVariable(inputArr, input);
|
||||||
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
|
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
|
||||||
assertEquals(map.get("input"), map.get("concat"));
|
assertEquals(map.get("input"), map.get("concat"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTrainingConfigJson(){
|
public void testTrainingConfigJson(){
|
||||||
|
|
|
@ -17,6 +17,13 @@
|
||||||
package org.nd4j.autodiff.samediff.listeners;
|
package org.nd4j.autodiff.samediff.listeners;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.autodiff.listeners.At;
|
||||||
|
import org.nd4j.autodiff.listeners.BaseListener;
|
||||||
|
import org.nd4j.autodiff.listeners.Listener;
|
||||||
|
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||||
|
import org.nd4j.autodiff.listeners.ListenerVariables;
|
||||||
|
import org.nd4j.autodiff.listeners.Loss;
|
||||||
|
import org.nd4j.autodiff.listeners.Operation;
|
||||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||||
import org.nd4j.autodiff.listeners.records.History;
|
import org.nd4j.autodiff.listeners.records.History;
|
||||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||||
|
@ -351,7 +358,7 @@ public class ListenerTest extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
|
public void iterationDone(final SameDiff sd, final At at, final MultiDataSet dataSet, final Loss loss) {
|
||||||
iterationDoneCount++;
|
iterationDoneCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,13 @@ import org.nd4j.autodiff.samediff.VariableType;
|
||||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
|
import org.nd4j.graph.FlatArray;
|
||||||
|
import org.nd4j.graph.UIAddName;
|
||||||
|
import org.nd4j.graph.UIEvent;
|
||||||
|
import org.nd4j.graph.UIGraphStructure;
|
||||||
|
import org.nd4j.graph.UIInfoType;
|
||||||
|
import org.nd4j.graph.UIOp;
|
||||||
|
import org.nd4j.graph.UIVariable;
|
||||||
import org.nd4j.graph.ui.LogFileWriter;
|
import org.nd4j.graph.ui.LogFileWriter;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -150,7 +157,6 @@ public class FileReadWriteTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertEquals(UIInfoType.START_EVENTS, read.getData().get(1).getFirst().infoType());
|
assertEquals(UIInfoType.START_EVENTS, read.getData().get(1).getFirst().infoType());
|
||||||
|
|
||||||
|
|
||||||
//Append a number of events
|
//Append a number of events
|
||||||
w.registerEventName("accuracy");
|
w.registerEventName("accuracy");
|
||||||
for( int iter=0; iter<3; iter++) {
|
for( int iter=0; iter<3; iter++) {
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||||
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
import org.nd4j.evaluation.classification.ROCBinary;
|
||||||
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
@ -107,7 +113,6 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
||||||
assertTrue(t.getMessage(), t.getMessage().contains("no data"));
|
assertTrue(t.getMessage(), t.getMessage().contains("no data"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -17,6 +17,12 @@
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||||
|
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||||
|
import org.nd4j.evaluation.classification.ROC;
|
||||||
|
import org.nd4j.evaluation.classification.ROCBinary;
|
||||||
|
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||||
import org.nd4j.evaluation.curves.Histogram;
|
import org.nd4j.evaluation.curves.Histogram;
|
||||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||||
import org.nd4j.evaluation.curves.RocCurve;
|
import org.nd4j.evaluation.curves.RocCurve;
|
||||||
|
@ -100,8 +106,6 @@ public class EvalJsonTest extends BaseNd4jTest {
|
||||||
|
|
||||||
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
|
regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
for (IEvaluation e : arr) {
|
for (IEvaluation e : arr) {
|
||||||
String json = e.toJson();
|
String json = e.toJson();
|
||||||
if (print) {
|
if (print) {
|
||||||
|
|
|
@ -22,6 +22,11 @@ import org.junit.Test;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.OpPredicate;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraph;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
|
||||||
|
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
|
||||||
import org.nd4j.graph.ui.LogFileWriter;
|
import org.nd4j.graph.ui.LogFileWriter;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.imports.tensorflow.TFImportOverride;
|
import org.nd4j.imports.tensorflow.TFImportOverride;
|
||||||
|
|
|
@ -35,7 +35,7 @@ import java.util.ServiceLoader;
|
||||||
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
|
public class Nd4jTestSuite extends BlockJUnit4ClassRunner {
|
||||||
//the system property for what backends should run
|
//the system property for what backends should run
|
||||||
public final static String BACKENDS_TO_LOAD = "backends";
|
public final static String BACKENDS_TO_LOAD = "backends";
|
||||||
private static List<Nd4jBackend> BACKENDS;
|
private static List<Nd4jBackend> BACKENDS = new ArrayList<>();
|
||||||
static {
|
static {
|
||||||
ServiceLoader<Nd4jBackend> loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
|
ServiceLoader<Nd4jBackend> loadedBackends = ND4JClassLoading.loadService(Nd4jBackend.class);
|
||||||
for (Nd4jBackend backend : loadedBackends) {
|
for (Nd4jBackend backend : loadedBackends) {
|
||||||
|
|
|
@ -16,6 +16,13 @@
|
||||||
|
|
||||||
package org.nd4j.linalg;
|
package org.nd4j.linalg;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertNotEquals;
|
||||||
|
import static org.junit.Assert.assertNotNull;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -23,10 +30,18 @@ import lombok.var;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
|
import org.apache.commons.math3.stat.descriptive.rank.Percentile;
|
||||||
import org.apache.commons.math3.util.FastMath;
|
import org.apache.commons.math3.util.FastMath;
|
||||||
import org.junit.*;
|
import org.junit.After;
|
||||||
|
import org.junit.Before;
|
||||||
|
import org.junit.Ignore;
|
||||||
|
import org.junit.Rule;
|
||||||
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
|
import org.nd4j.common.util.MathUtils;
|
||||||
import org.nd4j.enums.WeightsFormat;
|
import org.nd4j.enums.WeightsFormat;
|
||||||
import org.nd4j.imports.TFGraphs.NodeReader;
|
import org.nd4j.imports.TFGraphs.NodeReader;
|
||||||
import org.nd4j.linalg.api.blas.Level1;
|
import org.nd4j.linalg.api.blas.Level1;
|
||||||
|
@ -47,6 +62,14 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.Op;
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
|
||||||
|
@ -64,6 +87,11 @@ import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
|
import org.nd4j.linalg.api.ops.impl.reduce.same.Sum;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
|
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals;
|
||||||
|
@ -73,8 +101,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
|
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
|
import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
|
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set;
|
||||||
|
@ -94,20 +122,28 @@ import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
import org.nd4j.common.util.MathUtils;
|
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.ByteArrayInputStream;
|
||||||
|
import java.io.ByteArrayOutputStream;
|
||||||
|
import java.io.DataInputStream;
|
||||||
|
import java.io.DataOutputStream;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileOutputStream;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.ObjectInputStream;
|
||||||
|
import java.io.ObjectOutputStream;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
import static org.junit.Assert.*;
|
import java.util.Collections;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* NDArrayTests
|
* NDArrayTests
|
||||||
|
@ -148,8 +184,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
Nd4j.setDataType(initialType);
|
Nd4j.setDataType(initialType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testArangeNegative() {
|
public void testArangeNegative() {
|
||||||
INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE);
|
INDArray arr = Nd4j.arange(-2,2).castTo(DataType.DOUBLE);
|
||||||
|
@ -241,9 +275,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
INDArray inDup = in.dup();
|
INDArray inDup = in.dup();
|
||||||
|
|
||||||
// System.out.println(in);
|
|
||||||
// System.out.println(inDup);
|
|
||||||
|
|
||||||
assertEquals(arr, in); //Passes: Original array "in" is OK, but array "inDup" is not!?
|
assertEquals(arr, in); //Passes: Original array "in" is OK, but array "inDup" is not!?
|
||||||
assertEquals(in, inDup); //Fails
|
assertEquals(in, inDup); //Fails
|
||||||
}
|
}
|
||||||
|
@ -310,7 +341,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(assertion,test);
|
assertEquals(assertion,test);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAudoBroadcastAddMatrix() {
|
public void testAudoBroadcastAddMatrix() {
|
||||||
INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2);
|
INDArray arr = Nd4j.linspace(1,4,4, DataType.DOUBLE).reshape(2,2);
|
||||||
|
@ -336,7 +366,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTensorAlongDimension() {
|
public void testTensorAlongDimension() {
|
||||||
val shape = new long[] {4, 5, 7};
|
val shape = new long[] {4, 5, 7};
|
||||||
|
@ -538,7 +567,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetColumns() {
|
public void testGetColumns() {
|
||||||
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
INDArray matrix = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
|
||||||
|
@ -2719,7 +2747,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, zOutCF); //fails
|
assertEquals(exp, zOutCF); //fails
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBroadcastDiv() {
|
public void testBroadcastDiv() {
|
||||||
INDArray num = Nd4j.create(new double[] {1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 1.00, 1.00, 1.00, 1.00,
|
INDArray num = Nd4j.create(new double[] {1.00, 1.00, 1.00, 1.00, 2.00, 2.00, 2.00, 2.00, 1.00, 1.00, 1.00, 1.00,
|
||||||
|
@ -2753,7 +2780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBroadcastMult() {
|
public void testBroadcastMult() {
|
||||||
INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00,
|
INDArray num = Nd4j.create(new double[] {1.00, 2.00, 3.00, 4.00, 5.00, 6.00, 7.00, 8.00, -1.00, -2.00, -3.00,
|
||||||
|
@ -2796,7 +2822,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(expected, actual);
|
assertEquals(expected, actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDimension() {
|
public void testDimension() {
|
||||||
INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
|
INDArray test = Nd4j.create(Nd4j.linspace(1, 4, 4, DataType.DOUBLE).data(), new long[] {2, 2});
|
||||||
|
@ -4595,8 +4620,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
-1.25485503673});
|
-1.25485503673});
|
||||||
|
|
||||||
INDArray reduced = Nd4j.getExecutioner().exec(new CosineDistance(haystack, needle, 1));
|
INDArray reduced = Nd4j.getExecutioner().exec(new CosineDistance(haystack, needle, 1));
|
||||||
// log.info("Reduced: {}", reduced);
|
|
||||||
|
|
||||||
|
|
||||||
INDArray exp = Nd4j.create(new double[] {0.577452, 0.0, 1.80182});
|
INDArray exp = Nd4j.create(new double[] {0.577452, 0.0, 1.80182});
|
||||||
assertEquals(exp, reduced);
|
assertEquals(exp, reduced);
|
||||||
|
@ -4606,9 +4629,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
double res = Nd4j.getExecutioner().execAndReturn(new CosineDistance(row, needle)).z().getDouble(0);
|
double res = Nd4j.getExecutioner().execAndReturn(new CosineDistance(row, needle)).z().getDouble(0);
|
||||||
assertEquals("Failed at " + i, reduced.getDouble(i), res, 1e-5);
|
assertEquals("Failed at " + i, reduced.getDouble(i), res, 1e-5);
|
||||||
}
|
}
|
||||||
//cosinedistance([-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
|
|
||||||
//cosinedistance([.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247], [-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673)
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -4677,8 +4697,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
|
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
|
||||||
.doubleValue();
|
.doubleValue();
|
||||||
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
||||||
|
|
||||||
// log.info("Euclidean: {} vs {} is {}", x, needle, res);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4698,8 +4716,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
|
double res = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(x, needle)).getFinalResult()
|
||||||
.doubleValue();
|
.doubleValue();
|
||||||
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
||||||
|
|
||||||
// log.info("Euclidean: {} vs {} is {}", x, needle, res);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4720,8 +4736,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(x, needle)).getFinalResult()
|
double res = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(x, needle)).getFinalResult()
|
||||||
.doubleValue();
|
.doubleValue();
|
||||||
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
assertEquals("Failed at " + i, reduced.getDouble(i), res, 0.001);
|
||||||
|
|
||||||
// log.info("Cosine: {} vs {} is {}", x, needle, res);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4755,7 +4769,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAtan2_1() {
|
public void testAtan2_1() {
|
||||||
INDArray x = Nd4j.create(10).assign(-1.0);
|
INDArray x = Nd4j.create(10).assign(-1.0);
|
||||||
|
@ -4767,7 +4780,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, z);
|
assertEquals(exp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAtan2_2() {
|
public void testAtan2_2() {
|
||||||
INDArray x = Nd4j.create(10).assign(1.0);
|
INDArray x = Nd4j.create(10).assign(1.0);
|
||||||
|
@ -4779,7 +4791,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, z);
|
assertEquals(exp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJaccardDistance1() {
|
public void testJaccardDistance1() {
|
||||||
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0});
|
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 0});
|
||||||
|
@ -4790,7 +4801,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(0.75, val, 1e-5);
|
assertEquals(0.75, val, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJaccardDistance2() {
|
public void testJaccardDistance2() {
|
||||||
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 1});
|
INDArray x = Nd4j.create(new double[] {0, 1, 0, 0, 1, 1});
|
||||||
|
@ -4811,7 +4821,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(2.0 / 6, val, 1e-5);
|
assertEquals(2.0 / 6, val, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testHammingDistance2() {
|
public void testHammingDistance2() {
|
||||||
INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0});
|
INDArray x = Nd4j.create(new double[] {0, 0, 0, 1, 0, 0});
|
||||||
|
@ -4822,7 +4831,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(3.0 / 6, val, 1e-5);
|
assertEquals(3.0 / 6, val, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testHammingDistance3() {
|
public void testHammingDistance3() {
|
||||||
INDArray x = Nd4j.create(DataType.DOUBLE, 10, 6);
|
INDArray x = Nd4j.create(DataType.DOUBLE, 10, 6);
|
||||||
|
@ -4831,7 +4839,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
x.getRow(r).putScalar(p, 1);
|
x.getRow(r).putScalar(p, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// log.info("X: {}", x);
|
|
||||||
INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0});
|
INDArray y = Nd4j.create(new double[] {0, 0, 0, 0, 1, 0});
|
||||||
|
|
||||||
INDArray res = Nd4j.getExecutioner().exec(new HammingDistance(x, y, 1));
|
INDArray res = Nd4j.getExecutioner().exec(new HammingDistance(x, y, 1));
|
||||||
|
@ -4846,7 +4853,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances1() {
|
public void testAllDistances1() {
|
||||||
INDArray initialX = Nd4j.create(5, 10);
|
INDArray initialX = Nd4j.create(5, 10);
|
||||||
|
@ -4879,7 +4885,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances2() {
|
public void testAllDistances2() {
|
||||||
INDArray initialX = Nd4j.create(5, 10);
|
INDArray initialX = Nd4j.create(5, 10);
|
||||||
|
@ -4940,7 +4945,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances3_Large() {
|
public void testAllDistances3_Large() {
|
||||||
INDArray initialX = Nd4j.create(5, 2000);
|
INDArray initialX = Nd4j.create(5, 2000);
|
||||||
|
@ -4968,13 +4972,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
double res = result.getDouble(x, y);
|
double res = result.getDouble(x, y);
|
||||||
double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup());
|
double exp = Transforms.euclideanDistance(rowX, initialY.getRow(y).dup());
|
||||||
|
|
||||||
//log.info("Expected [{}, {}]: {}",x, y, exp);
|
|
||||||
assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001);
|
assertEquals("Failed for [" + x + ", " + y + "]", exp, res, 0.001);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances3_Large_Columns() {
|
public void testAllDistances3_Large_Columns() {
|
||||||
INDArray initialX = Nd4j.create(2000, 5);
|
INDArray initialX = Nd4j.create(2000, 5);
|
||||||
|
@ -5005,7 +5007,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances4_Large_Columns() {
|
public void testAllDistances4_Large_Columns() {
|
||||||
INDArray initialX = Nd4j.create(2000, 5);
|
INDArray initialX = Nd4j.create(2000, 5);
|
||||||
|
@ -5095,8 +5096,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAllDistances3() {
|
public void testAllDistances3() {
|
||||||
Nd4j.getRandom().setSeed(123);
|
Nd4j.getRandom().setSeed(123);
|
||||||
|
@ -5122,7 +5121,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStridedTransforms1() {
|
public void testStridedTransforms1() {
|
||||||
//output: Rank: 2,Offset: 0
|
//output: Rank: 2,Offset: 0
|
||||||
|
@ -5176,7 +5174,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEntropy3() {
|
public void testEntropy3() {
|
||||||
INDArray x = Nd4j.rand(1, 100);
|
INDArray x = Nd4j.rand(1, 100);
|
||||||
|
@ -5197,7 +5194,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, res, 1e-5);
|
assertEquals(exp, res, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected double getShannonEntropy(double[] array) {
|
protected double getShannonEntropy(double[] array) {
|
||||||
double ret = 0;
|
double ret = 0;
|
||||||
for (double x : array) {
|
for (double x : array) {
|
||||||
|
@ -5207,12 +5203,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
return -ret;
|
return -ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected double getLogEntropy(double[] array) {
|
protected double getLogEntropy(double[] array) {
|
||||||
return Math.log(MathUtils.entropy(array));
|
return Math.log(MathUtils.entropy(array));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testReverse1() {
|
public void testReverse1() {
|
||||||
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
INDArray array = Nd4j.create(new double[] {9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
||||||
|
@ -5228,8 +5222,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
INDArray array = Nd4j.create(new double[] {10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
|
||||||
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
|
INDArray exp = Nd4j.create(new double[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
|
||||||
|
|
||||||
// log.info("Array shapeInfo: {}", array.shapeInfoJava());
|
|
||||||
|
|
||||||
INDArray rev = Nd4j.reverse(array);
|
INDArray rev = Nd4j.reverse(array);
|
||||||
|
|
||||||
assertEquals(exp, rev);
|
assertEquals(exp, rev);
|
||||||
|
@ -5278,7 +5270,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertTrue(rev == array);
|
assertTrue(rev == array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNativeSortView1() {
|
public void testNativeSortView1() {
|
||||||
INDArray matrix = Nd4j.create(10, 10);
|
INDArray matrix = Nd4j.create(10, 10);
|
||||||
|
@ -5291,9 +5282,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
Nd4j.sort(matrix.getColumn(0), true);
|
Nd4j.sort(matrix.getColumn(0), true);
|
||||||
|
|
||||||
|
|
||||||
// log.info("Matrix: {}", matrix);
|
|
||||||
|
|
||||||
assertEquals(exp, matrix.getColumn(0));
|
assertEquals(exp, matrix.getColumn(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5384,9 +5372,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
Transforms.reverse(array, false);
|
Transforms.reverse(array, false);
|
||||||
|
|
||||||
// log.info("Reversed shapeInfo: {}", array.shapeInfoJava());
|
|
||||||
// log.info("Reversed: {}", array);
|
|
||||||
|
|
||||||
Transforms.reverse(array, false);
|
Transforms.reverse(array, false);
|
||||||
|
|
||||||
val jexp = exp.data().asInt();
|
val jexp = exp.data().asInt();
|
||||||
|
@ -5401,9 +5386,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
val exp = array.dup(array.ordering());
|
val exp = array.dup(array.ordering());
|
||||||
|
|
||||||
val reversed = Transforms.reverse(array, true);
|
val reversed = Transforms.reverse(array, true);
|
||||||
|
|
||||||
// log.info("Reversed: {}", reversed);
|
|
||||||
|
|
||||||
val rereversed = Transforms.reverse(reversed, true);
|
val rereversed = Transforms.reverse(reversed, true);
|
||||||
|
|
||||||
val jexp = exp.data().asInt();
|
val jexp = exp.data().asInt();
|
||||||
|
@ -5445,8 +5427,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
|
INDArray array = Nd4j.linspace(1, 2017152, 2017152, DataType.DOUBLE).reshape(1, -1);
|
||||||
INDArray exp = array.dup();
|
INDArray exp = array.dup();
|
||||||
Transforms.reverse(array, false);
|
Transforms.reverse(array, false);
|
||||||
// log.info("Reverse: {}", array);
|
|
||||||
|
|
||||||
|
|
||||||
long time1 = System.currentTimeMillis();
|
long time1 = System.currentTimeMillis();
|
||||||
INDArray res = Nd4j.sort(array, true);
|
INDArray res = Nd4j.sort(array, true);
|
||||||
|
@ -5465,7 +5445,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
assertNotEquals(exp1, dps);
|
assertNotEquals(exp1, dps);
|
||||||
|
|
||||||
|
|
||||||
for (int r = 0; r < array.rows(); r++) {
|
for (int r = 0; r < array.rows(); r++) {
|
||||||
array.getRow(r).assign(dps);
|
array.getRow(r).assign(dps);
|
||||||
}
|
}
|
||||||
|
@ -5485,7 +5464,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected boolean checkIfUnique(INDArray array, int iteration) {
|
protected boolean checkIfUnique(INDArray array, int iteration) {
|
||||||
var jarray = array.data().asInt();
|
var jarray = array.data().asInt();
|
||||||
var set = new HashSet<Integer>();
|
var set = new HashSet<Integer>();
|
||||||
|
@ -5698,7 +5676,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBroadcastAMax() {
|
public void testBroadcastAMax() {
|
||||||
INDArray matrix = Nd4j.create(5, 5);
|
INDArray matrix = Nd4j.create(5, 5);
|
||||||
|
@ -5715,7 +5692,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBroadcastAMin() {
|
public void testBroadcastAMin() {
|
||||||
INDArray matrix = Nd4j.create(5, 5);
|
INDArray matrix = Nd4j.create(5, 5);
|
||||||
|
@ -5767,7 +5743,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(exp, res);
|
assertEquals(exp, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRDiv1() {
|
public void testRDiv1() {
|
||||||
val argX = Nd4j.create(3).assign(2.0);
|
val argX = Nd4j.create(3).assign(2.0);
|
||||||
|
@ -5789,7 +5764,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(arrayC, arrayF);
|
assertEquals(arrayC, arrayF);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMatchTransform() {
|
public void testMatchTransform() {
|
||||||
val array = Nd4j.create(new double[] {1, 1, 1, 0, 1, 1},'c');
|
val array = Nd4j.create(new double[] {1, 1, 1, 0, 1, 1},'c');
|
||||||
|
@ -5838,10 +5812,6 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
val a = Nd4j.linspace(1, x * A1 * A2, x * A1 * A2, DataType.DOUBLE).reshape(x, A1, A2);
|
val a = Nd4j.linspace(1, x * A1 * A2, x * A1 * A2, DataType.DOUBLE).reshape(x, A1, A2);
|
||||||
val b = Nd4j.linspace(1, x * B1 * B2, x * B1 * B2, DataType.DOUBLE).reshape(x, B1, B2);
|
val b = Nd4j.linspace(1, x * B1 * B2, x * B1 * B2, DataType.DOUBLE).reshape(x, B1, B2);
|
||||||
|
|
||||||
//
|
|
||||||
|
|
||||||
//log.info("C shape: {}", Arrays.toString(c.shapeInfoDataBuffer().asInt()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -21,6 +21,21 @@ import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationCube;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationELU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationGELU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationHardTanH;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationRReLU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationReLU;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
|
||||||
|
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -73,7 +88,6 @@ public class TestActivation extends BaseNd4jTest {
|
||||||
double[] dIn = in.data().asDouble();
|
double[] dIn = in.data().asDouble();
|
||||||
|
|
||||||
for( int i=0; i<max.length; i++ ){
|
for( int i=0; i<max.length; i++ ){
|
||||||
// System.out.println("i = " + i);
|
|
||||||
ActivationReLU r = new ActivationReLU(max[i], threshold[i], negativeSlope[i]);
|
ActivationReLU r = new ActivationReLU(max[i], threshold[i], negativeSlope[i]);
|
||||||
INDArray out = r.getActivation(in.dup(), true);
|
INDArray out = r.getActivation(in.dup(), true);
|
||||||
double[] exp = new double[dIn.length];
|
double[] exp = new double[dIn.length];
|
||||||
|
@ -145,7 +159,6 @@ public class TestActivation extends BaseNd4jTest {
|
||||||
|
|
||||||
for (int i = 0; i < activations.length; i++) {
|
for (int i = 0; i < activations.length; i++) {
|
||||||
String asJson = mapper.writeValueAsString(activations[i]);
|
String asJson = mapper.writeValueAsString(activations[i]);
|
||||||
// System.out.println(asJson);
|
|
||||||
|
|
||||||
JsonNode node = mapper.readTree(asJson);
|
JsonNode node = mapper.readTree(asJson);
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,13 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||||
|
import org.nd4j.linalg.indexing.IntervalIndex;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.indexing.NDArrayIndexAll;
|
||||||
|
import org.nd4j.linalg.indexing.NewAxis;
|
||||||
|
import org.nd4j.linalg.indexing.PointIndex;
|
||||||
|
import org.nd4j.linalg.indexing.SpecifiedIndex;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
|
|
||||||
|
@ -75,8 +82,6 @@ public class IndexingTestsC extends BaseNd4jTest {
|
||||||
|
|
||||||
final INDArray aBad = col.broadcast(2, 2);
|
final INDArray aBad = col.broadcast(2, 2);
|
||||||
final INDArray aGood = col.dup().broadcast(2, 2);
|
final INDArray aGood = col.dup().broadcast(2, 2);
|
||||||
// System.out.println(aBad);
|
|
||||||
// System.out.println(aGood);
|
|
||||||
assertTrue(Transforms.abs(aGood.sub(aBad).div(aGood)).maxNumber().doubleValue() < 0.01);
|
assertTrue(Transforms.abs(aGood.sub(aBad).div(aGood)).maxNumber().doubleValue() < 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -446,12 +451,6 @@ public class IndexingTestsC extends BaseNd4jTest {
|
||||||
msg = "Test case: rank = " + rank + ", order = " + order + ", inShape = " + Arrays.toString(inShape) +
|
msg = "Test case: rank = " + rank + ", order = " + order + ", inShape = " + Arrays.toString(inShape) +
|
||||||
", outShape = " + Arrays.toString(expShape) +
|
", outShape = " + Arrays.toString(expShape) +
|
||||||
", indexes = " + Arrays.toString(indexes) + ", newAxisTest=" + newAxisTestCase;
|
", indexes = " + Arrays.toString(indexes) + ", newAxisTest=" + newAxisTestCase;
|
||||||
// System.out.println(msg);
|
|
||||||
|
|
||||||
// System.out.println(arr);
|
|
||||||
// System.out.println(sub);
|
|
||||||
// System.out.println();
|
|
||||||
|
|
||||||
|
|
||||||
NdIndexIterator posIter = new NdIndexIterator(expShape);
|
NdIndexIterator posIter = new NdIndexIterator(expShape);
|
||||||
while (posIter.hasNext()) {
|
while (posIter.hasNext()) {
|
||||||
|
@ -467,7 +466,6 @@ public class IndexingTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// System.out.println("TOTAL TEST CASES: " + totalTestCaseCount);
|
|
||||||
assertTrue(String.valueOf(totalTestCaseCount), totalTestCaseCount > 5000);
|
assertTrue(String.valueOf(totalTestCaseCount), totalTestCaseCount > 5000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -556,14 +554,8 @@ public class IndexingTestsC extends BaseNd4jTest {
|
||||||
char order = 'c';
|
char order = 'c';
|
||||||
INDArray arr = Nd4j.linspace(DataType.FLOAT, 1, prod, prod).reshape('c', inShape).dup(order);
|
INDArray arr = Nd4j.linspace(DataType.FLOAT, 1, prod, prod).reshape('c', inShape).dup(order);
|
||||||
INDArray sub = arr.get(indexes);
|
INDArray sub = arr.get(indexes);
|
||||||
|
|
||||||
// System.out.println(Arrays.toString(indexes));
|
|
||||||
// System.out.println(arr);
|
|
||||||
// System.out.println();
|
|
||||||
// System.out.println(sub);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public char ordering() {
|
public char ordering() {
|
||||||
return 'c';
|
return 'c';
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
package org.nd4j.linalg.convolution;
|
package org.nd4j.linalg.convolution;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
import org.nd4j.common.resources.Resources;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -10,12 +15,13 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.common.resources.Resources;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.*;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import static org.junit.Assert.*;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
public class DeconvTests extends BaseNd4jTest {
|
public class DeconvTests extends BaseNd4jTest {
|
||||||
|
|
||||||
|
@ -33,10 +39,10 @@ public class DeconvTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void compareKeras() throws Exception {
|
public void compareKeras() throws Exception {
|
||||||
File f = testDir.newFolder();
|
File newFolder = testDir.newFolder();
|
||||||
Resources.copyDirectory("keras/deconv", f);
|
new ClassPathResource("keras/deconv/").copyDirectory(newFolder);
|
||||||
|
|
||||||
File[] files = f.listFiles();
|
File[] files = newFolder.listFiles();
|
||||||
|
|
||||||
Set<String> tests = new HashSet<>();
|
Set<String> tests = new HashSet<>();
|
||||||
for(File file : files){
|
for(File file : files){
|
||||||
|
@ -64,10 +70,10 @@ public class DeconvTests extends BaseNd4jTest {
|
||||||
int d = Integer.parseInt(nums[5]);
|
int d = Integer.parseInt(nums[5]);
|
||||||
boolean nchw = s.contains("nchw");
|
boolean nchw = s.contains("nchw");
|
||||||
|
|
||||||
INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy"));
|
INDArray w = Nd4j.readNpy(new File(newFolder, s + "_W.npy"));
|
||||||
INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy"));
|
INDArray b = Nd4j.readNpy(new File(newFolder, s + "_b.npy"));
|
||||||
INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT);
|
INDArray in = Nd4j.readNpy(new File(newFolder, s + "_in.npy")).castTo(DataType.FLOAT);
|
||||||
INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy"));
|
INDArray expOut = Nd4j.readNpy(new File(newFolder, s + "_out.npy"));
|
||||||
|
|
||||||
CustomOp op = DynamicCustomOp.builder("deconv2d")
|
CustomOp op = DynamicCustomOp.builder("deconv2d")
|
||||||
.addInputs(in, w, b)
|
.addInputs(in, w, b)
|
||||||
|
|
|
@ -26,6 +26,37 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.AdjustHue;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.BetaInc;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.BitCast;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.CompareAndBitpack;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.DivideNoNan;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Flatten;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.FusedBatchNorm;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.HsvToRgb;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Lgamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.LinearSolve;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Logdet;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Lstsq;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Lu;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.MatrixBandPart;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Polygamma;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.RandomCrop;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.RgbToGrayscale;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.RgbToHsv;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.RgbToYiq;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.RgbToYuv;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.Roll;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.ToggleBits;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.TriangularSolve;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.YiqToRgb;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.YuvToRgb;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
||||||
|
@ -373,7 +404,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
ScatterUpdate op = new ScatterUpdate(matrix, updates, indices, dims, ScatterUpdate.UpdateOp.ADD);
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
|
||||||
// log.info("Matrix: {}", matrix);
|
|
||||||
assertEquals(exp0, matrix.getRow(0));
|
assertEquals(exp0, matrix.getRow(0));
|
||||||
assertEquals(exp1, matrix.getRow(1));
|
assertEquals(exp1, matrix.getRow(1));
|
||||||
assertEquals(exp0, matrix.getRow(2));
|
assertEquals(exp0, matrix.getRow(2));
|
||||||
|
@ -1384,8 +1414,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
|
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
|
||||||
val c = Conditions.equals(0.0);
|
val c = Conditions.equals(0.0);
|
||||||
|
|
||||||
// System.out.println("Y:\n" + y);
|
|
||||||
|
|
||||||
INDArray z = x.match(y, c);
|
INDArray z = x.match(y, c);
|
||||||
INDArray exp = Nd4j.createFromArray(new boolean[][]{
|
INDArray exp = Nd4j.createFromArray(new boolean[][]{
|
||||||
{false, false, false},
|
{false, false, false},
|
||||||
|
@ -1396,7 +1424,6 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, z);
|
assertEquals(exp, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateOp_1() {
|
public void testCreateOp_1() {
|
||||||
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
|
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
|
||||||
|
@ -1862,11 +1889,9 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
System.out.println("in: " + in.shapeInfoToString());
|
System.out.println("in: " + in.shapeInfoToString());
|
||||||
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
|
System.out.println("inBadStrides: " + inBadStrides.shapeInfoToString());
|
||||||
|
|
||||||
|
|
||||||
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
|
INDArray out = Nd4j.create(DataType.FLOAT, 2, 12, 3, 3);
|
||||||
INDArray out2 = out.like();
|
INDArray out2 = out.like();
|
||||||
|
|
||||||
|
|
||||||
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
|
CustomOp op1 = DynamicCustomOp.builder("space_to_depth")
|
||||||
.addInputs(in)
|
.addInputs(in)
|
||||||
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
|
.addIntegerArguments(2, 0) //nchw = 0, nhwc = 1
|
||||||
|
|
|
@ -22,6 +22,14 @@ import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.junit.runners.Parameterized;
|
import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.AbstractDataSetNormalizer;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.serializer.CustomSerializerStrategy;
|
import org.nd4j.linalg.dataset.api.preprocessor.serializer.CustomSerializerStrategy;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
|
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializer;
|
||||||
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
|
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
|
||||||
|
@ -65,7 +73,6 @@ public class NormalizerSerializerTest extends BaseNd4jTest {
|
||||||
|
|
||||||
ImagePreProcessingScaler restored = SUT.restore(tmpFile);
|
ImagePreProcessingScaler restored = SUT.restore(tmpFile);
|
||||||
assertEquals(imagePreProcessingScaler,restored);
|
assertEquals(imagePreProcessingScaler,restored);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -25,6 +25,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiDataNormalization;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerMinMaxScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerStandardize;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||||
|
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
|
|
@ -24,6 +24,12 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaDelta;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaMax;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.Nadam;
|
||||||
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@ -53,7 +59,6 @@ public class UpdaterTest extends BaseNd4jTest {
|
||||||
int rows = 10;
|
int rows = 10;
|
||||||
int cols = 2;
|
int cols = 2;
|
||||||
|
|
||||||
|
|
||||||
NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
|
NesterovsUpdater grad = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
|
||||||
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
|
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
|
||||||
INDArray W = Nd4j.zeros(rows, cols);
|
INDArray W = Nd4j.zeros(rows, cols);
|
||||||
|
@ -68,13 +73,11 @@ public class UpdaterTest extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAdaGrad() {
|
public void testAdaGrad() {
|
||||||
int rows = 10;
|
int rows = 10;
|
||||||
int cols = 2;
|
int cols = 2;
|
||||||
|
|
||||||
|
|
||||||
AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
|
AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
|
||||||
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
|
grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new long[] {rows, cols}, 'c', true);
|
||||||
INDArray W = Nd4j.zeros(rows, cols);
|
INDArray W = Nd4j.zeros(rows, cols);
|
||||||
|
|
|
@ -23,6 +23,15 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
|
import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.learning.config.AMSGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaDelta;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaGrad;
|
||||||
|
import org.nd4j.linalg.learning.config.AdaMax;
|
||||||
|
import org.nd4j.linalg.learning.config.Adam;
|
||||||
|
import org.nd4j.linalg.learning.config.Nadam;
|
||||||
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||||||
|
import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
|
@ -21,6 +21,21 @@ import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossCosineProximity;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossHinge;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossKLD;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossL1;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossL2;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMultiLabel;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossPoisson;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossSquaredHinge;
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||||
import org.nd4j.shade.jackson.databind.MapperFeature;
|
import org.nd4j.shade.jackson.databind.MapperFeature;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
|
|
|
@ -28,6 +28,16 @@ import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossL1;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossL2;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMAE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMAPE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossMSLE;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
||||||
|
import org.nd4j.linalg.lossfunctions.impl.LossSparseMCXENT;
|
||||||
|
|
||||||
import static junit.framework.TestCase.assertFalse;
|
import static junit.framework.TestCase.assertFalse;
|
||||||
import static junit.framework.TestCase.assertTrue;
|
import static junit.framework.TestCase.assertTrue;
|
||||||
|
|
|
@ -24,6 +24,13 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
import org.nd4j.linalg.api.ops.BaseBroadcastOp;
|
||||||
|
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
|
||||||
|
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
|
||||||
|
import org.nd4j.linalg.api.ops.BaseScalarOp;
|
||||||
|
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
|
||||||
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
|
@ -26,6 +26,12 @@ import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -84,7 +90,6 @@ public class DerivativeTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRectifiedLinearDerivative() {
|
public void testRectifiedLinearDerivative() {
|
||||||
//ReLU:
|
//ReLU:
|
||||||
|
@ -166,11 +171,7 @@ public class DerivativeTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray z = Transforms.hardSigmoid(xArr, true);
|
INDArray z = Transforms.hardSigmoid(xArr, true);
|
||||||
INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));
|
INDArray zPrime = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(xArr.dup()));;
|
||||||
|
|
||||||
// System.out.println(xArr);
|
|
||||||
// System.out.println(z);
|
|
||||||
// System.out.println(zPrime);
|
|
||||||
|
|
||||||
for (int i = 0; i < expHSOut.length; i++) {
|
for (int i = 0; i < expHSOut.length; i++) {
|
||||||
double relErrorHS =
|
double relErrorHS =
|
||||||
|
|
|
@ -32,6 +32,21 @@ import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
|
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||||
|
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||||
|
import org.nd4j.linalg.api.ops.random.custom.RandomGamma;
|
||||||
|
import org.nd4j.linalg.api.ops.random.custom.RandomPoisson;
|
||||||
|
import org.nd4j.linalg.api.ops.random.custom.RandomShuffle;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.DropOut;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.Linspace;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
|
||||||
import org.nd4j.linalg.api.rng.DefaultRandom;
|
import org.nd4j.linalg.api.rng.DefaultRandom;
|
||||||
import org.nd4j.linalg.api.rng.Random;
|
import org.nd4j.linalg.api.rng.Random;
|
||||||
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
import org.nd4j.linalg.api.rng.distribution.Distribution;
|
||||||
|
@ -78,7 +93,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCrossBackendEquality1() {
|
public void testCrossBackendEquality1() {
|
||||||
|
|
||||||
int[] shape = {12};
|
int[] shape = {12};
|
||||||
double mean = 0;
|
double mean = 0;
|
||||||
double standardDeviation = 1.0;
|
double standardDeviation = 1.0;
|
||||||
|
@ -87,8 +101,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray arr = Nd4j.getExecutioner().exec(new GaussianDistribution(
|
INDArray arr = Nd4j.getExecutioner().exec(new GaussianDistribution(
|
||||||
Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom());
|
Nd4j.createUninitialized(shape, Nd4j.order()), mean, standardDeviation), Nd4j.getRandom());
|
||||||
|
|
||||||
|
|
||||||
// log.info("arr: {}", arr.data().asDouble());
|
|
||||||
assertEquals(exp, arr);
|
assertEquals(exp, arr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,8 +117,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
|
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
|
||||||
Nd4j.getExecutioner().exec(distribution2, random2);
|
Nd4j.getExecutioner().exec(distribution2, random2);
|
||||||
|
|
||||||
// System.out.println("Data: " + z1);
|
|
||||||
// System.out.println("Data: " + z2);
|
|
||||||
for (int e = 0; e < z1.length(); e++) {
|
for (int e = 0; e < z1.length(); e++) {
|
||||||
double val = z1.getDouble(e);
|
double val = z1.getDouble(e);
|
||||||
assertTrue(val >= 1.0 && val <= 2.0);
|
assertTrue(val >= 1.0 && val <= 2.0);
|
||||||
|
@ -135,8 +145,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
|
|
||||||
log.info("States cpu: {}/{}", random1.rootState(), random1.nodeState());
|
log.info("States cpu: {}/{}", random1.rootState(), random1.nodeState());
|
||||||
|
|
||||||
// System.out.println("Data: " + z1);
|
|
||||||
// System.out.println("Data: " + z2);
|
|
||||||
for (int e = 0; e < z1.length(); e++) {
|
for (int e = 0; e < z1.length(); e++) {
|
||||||
double val = z1.getDouble(e);
|
double val = z1.getDouble(e);
|
||||||
assertTrue(val >= 1.0 && val <= 2.0);
|
assertTrue(val >= 1.0 && val <= 2.0);
|
||||||
|
@ -156,9 +164,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
|
UniformDistribution distribution2 = new UniformDistribution(z2, 1.0, 2.0);
|
||||||
Nd4j.getExecutioner().exec(distribution2, random1);
|
Nd4j.getExecutioner().exec(distribution2, random1);
|
||||||
|
|
||||||
// System.out.println("Data: " + z1);
|
|
||||||
// System.out.println("Data: " + z2);
|
|
||||||
|
|
||||||
assertNotEquals(z1, z2);
|
assertNotEquals(z1, z2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,7 +179,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray z2 = Nd4j.randn('c', new int[] {1, 1000});
|
INDArray z2 = Nd4j.randn('c', new int[] {1, 1000});
|
||||||
|
|
||||||
assertEquals("Failed on iteration " + i, z1, z2);
|
assertEquals("Failed on iteration " + i, z1, z2);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -190,7 +194,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray z2 = Nd4j.rand('c', new int[] {1, 1000});
|
INDArray z2 = Nd4j.rand('c', new int[] {1, 1000});
|
||||||
|
|
||||||
assertEquals("Failed on iteration " + i, z1, z2);
|
assertEquals("Failed on iteration " + i, z1, z2);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,7 +209,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray z2 = Nd4j.getExecutioner().exec(new BinomialDistribution(Nd4j.createUninitialized(1000), 10, 0.2));
|
INDArray z2 = Nd4j.getExecutioner().exec(new BinomialDistribution(Nd4j.createUninitialized(1000), 10, 0.2));
|
||||||
|
|
||||||
assertEquals("Failed on iteration " + i, z1, z2);
|
assertEquals("Failed on iteration " + i, z1, z2);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -222,8 +224,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(z1, z2);
|
assertEquals(z1, z2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDropoutInverted1() {
|
public void testDropoutInverted1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -318,7 +318,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(z1, z2);
|
assertEquals(z1, z2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGaussianDistribution2() {
|
public void testGaussianDistribution2() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -403,8 +402,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
|
Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
|
||||||
Nd4j.sort(z1, true);
|
Nd4j.sort(z1, true);
|
||||||
|
|
||||||
// System.out.println("Data for Anderson-Darling: " + z1);
|
|
||||||
|
|
||||||
for (int i = 0; i < n; i++) {
|
for (int i = 0; i < n; i++) {
|
||||||
|
|
||||||
Double res = nd.cumulativeProbability(z1.getDouble(i));
|
Double res = nd.cumulativeProbability(z1.getDouble(i));
|
||||||
|
@ -432,9 +429,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
public void testStepOver1() {
|
public void testStepOver1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
|
||||||
|
|
||||||
// log.info("1: ----------------");
|
|
||||||
|
|
||||||
INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0));
|
INDArray z0 = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(DataType.DOUBLE, 1000000), 0.0, 1.0));
|
||||||
|
|
||||||
assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
|
assertEquals(0.0, z0.meanNumber().doubleValue(), 0.01);
|
||||||
|
@ -442,36 +436,15 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
|
|
||||||
random1.setSeed(119);
|
random1.setSeed(119);
|
||||||
|
|
||||||
// log.info("2: ----------------");
|
|
||||||
|
|
||||||
INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
INDArray z1 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
||||||
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
||||||
|
|
||||||
GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
|
GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
|
||||||
Nd4j.getExecutioner().exec(op1, random1);
|
Nd4j.getExecutioner().exec(op1, random1);
|
||||||
|
|
||||||
// log.info("2: ----------------");
|
|
||||||
|
|
||||||
//log.info("End: [{}, {}, {}, {}]", z1.getFloat(29000000), z1.getFloat(29000001), z1.getFloat(29000002), z1.getFloat(29000003));
|
|
||||||
|
|
||||||
//log.info("Sum: {}", z1.sumNumber().doubleValue());
|
|
||||||
// log.info("Sum2: {}", z2.sumNumber().doubleValue());
|
|
||||||
|
|
||||||
|
|
||||||
INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()));
|
INDArray match = Nd4j.getExecutioner().exec(new MatchCondition(z1, Conditions.isNan()));
|
||||||
// log.info("NaNs: {}", match);
|
|
||||||
assertEquals(0.0f, match.getFloat(0), 0.01f);
|
assertEquals(0.0f, match.getFloat(0), 0.01f);
|
||||||
|
|
||||||
/*
|
|
||||||
for (int i = 0; i < z1.length(); i++) {
|
|
||||||
if (Double.isNaN(z1.getDouble(i)))
|
|
||||||
throw new IllegalStateException("NaN value found at " + i);
|
|
||||||
|
|
||||||
if (Double.isInfinite(z1.getDouble(i)))
|
|
||||||
throw new IllegalStateException("Infinite value found at " + i);
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
|
assertEquals(0.0, z1.meanNumber().doubleValue(), 0.01);
|
||||||
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
|
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
|
||||||
}
|
}
|
||||||
|
@ -480,7 +453,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
public void testSum_119() {
|
public void testSum_119() {
|
||||||
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
INDArray z2 = Nd4j.zeros(DataType.DOUBLE, 55000000);
|
||||||
val sum = z2.sumNumber().doubleValue();
|
val sum = z2.sumNumber().doubleValue();
|
||||||
// log.info("Sum2: {}", sum);
|
|
||||||
assertEquals(0.0, sum, 1e-5);
|
assertEquals(0.0, sum, 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -493,7 +465,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
|
assertEquals(1.0, z1.stdNumber().doubleValue(), 0.01);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSetSeed1() {
|
public void testSetSeed1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -533,8 +504,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(z02, z12);
|
assertEquals(z02, z12);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJavaSide1() {
|
public void testJavaSide1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -553,8 +522,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertArrayEquals(array1, array2, 1e-5f);
|
assertArrayEquals(array1, array2, 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJavaSide2() {
|
public void testJavaSide2() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -574,7 +541,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertArrayEquals(array1, array2);
|
assertArrayEquals(array1, array2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testJavaSide3() {
|
public void testJavaSide3() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -657,8 +623,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertNotEquals(0, sum);
|
assertNotEquals(0, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBernoulliDistribution1() {
|
public void testBernoulliDistribution1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -677,11 +641,8 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertNotEquals(z1Dup, z1);
|
assertNotEquals(z1Dup, z1);
|
||||||
|
|
||||||
assertEquals(z1, z2);
|
assertEquals(z1, z2);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBernoulliDistribution2() {
|
public void testBernoulliDistribution2() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -690,7 +651,8 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray z1 = Nd4j.zeros(20);
|
INDArray z1 = Nd4j.zeros(20);
|
||||||
INDArray z2 = Nd4j.zeros(20);
|
INDArray z2 = Nd4j.zeros(20);
|
||||||
INDArray z1Dup = Nd4j.zeros(20);
|
INDArray z1Dup = Nd4j.zeros(20);
|
||||||
INDArray exp = Nd4j.create(new double[] {0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000});
|
INDArray exp = Nd4j.create(new double[]{ 0, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000, 0, 1.0000, 1.0000,
|
||||||
|
1.0000, 0, 1.0000, 1.0000, 0, 0, 1.0000, 0, 1.0000 });
|
||||||
|
|
||||||
BernoulliDistribution op1 = new BernoulliDistribution(z1, 0.50);
|
BernoulliDistribution op1 = new BernoulliDistribution(z1, 0.50);
|
||||||
BernoulliDistribution op2 = new BernoulliDistribution(z2, 0.50);
|
BernoulliDistribution op2 = new BernoulliDistribution(z2, 0.50);
|
||||||
|
@ -705,7 +667,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, z1);
|
assertEquals(exp, z1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBernoulliDistribution3() {
|
public void testBernoulliDistribution3() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -716,7 +677,7 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray z1 = Nd4j.zeros(10);
|
INDArray z1 = Nd4j.zeros(10);
|
||||||
INDArray z2 = Nd4j.zeros(10);
|
INDArray z2 = Nd4j.zeros(10);
|
||||||
INDArray z1Dup = Nd4j.zeros(10);
|
INDArray z1Dup = Nd4j.zeros(10);
|
||||||
INDArray exp = Nd4j.create(new double[] {1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0});
|
INDArray exp = Nd4j.create(new double[]{ 1.0000, 0, 0, 1.0000, 1.0000, 1.0000, 0, 1.0000, 0, 0 });
|
||||||
|
|
||||||
BernoulliDistribution op1 = new BernoulliDistribution(z1, prob);
|
BernoulliDistribution op1 = new BernoulliDistribution(z1, prob);
|
||||||
BernoulliDistribution op2 = new BernoulliDistribution(z2, prob);
|
BernoulliDistribution op2 = new BernoulliDistribution(z2, prob);
|
||||||
|
@ -731,7 +692,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, z1);
|
assertEquals(exp, z1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testBinomialDistribution1() {
|
public void testBinomialDistribution1() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
|
||||||
|
@ -780,7 +740,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
BooleanIndexing.and(z1, Conditions.greaterThanOrEqual(0.0));
|
BooleanIndexing.and(z1, Conditions.greaterThanOrEqual(0.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultithreading1() throws Exception {
|
public void testMultithreading1() throws Exception {
|
||||||
|
|
||||||
|
@ -822,7 +781,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultithreading2() throws Exception {
|
public void testMultithreading2() throws Exception {
|
||||||
|
|
||||||
|
@ -885,11 +843,11 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertNotEquals(someInt, otherInt);
|
assertNotEquals(someInt, otherInt);
|
||||||
|
|
||||||
} else
|
} else {
|
||||||
log.warn("Not a NativeRandom object received, skipping test");
|
log.warn("Not a NativeRandom object received, skipping test");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testStepOver4() {
|
public void testStepOver4() {
|
||||||
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000);
|
Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119, 100000);
|
||||||
|
@ -903,7 +861,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSignatures1() {
|
public void testSignatures1() {
|
||||||
|
|
||||||
|
@ -915,7 +872,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testChoice1() {
|
public void testChoice1() {
|
||||||
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
||||||
|
@ -926,7 +882,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, sampled);
|
assertEquals(exp, sampled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testChoice2() {
|
public void testChoice2() {
|
||||||
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
INDArray source = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
||||||
|
@ -937,8 +892,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, sampled);
|
assertEquals(exp, sampled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Ignore
|
@Ignore
|
||||||
@Test
|
@Test
|
||||||
public void testDeallocation1() throws Exception {
|
public void testDeallocation1() throws Exception {
|
||||||
|
@ -952,7 +905,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void someTest() {
|
public void someTest() {
|
||||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||||
|
@ -1353,9 +1305,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
|
|
||||||
log.info("Java mean: {}; Native mean: {}", mean, z01.meanNumber().doubleValue());
|
log.info("Java mean: {}; Native mean: {}", mean, z01.meanNumber().doubleValue());
|
||||||
assertEquals(mean, z01.meanNumber().doubleValue(), 1e-1);
|
assertEquals(mean, z01.meanNumber().doubleValue(), 1e-1);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -1364,44 +1313,32 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
INDArray exp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
INDArray exp = Nd4j.create(new double[] {1, 2, 3, 4, 5});
|
||||||
|
|
||||||
assertEquals(exp, res);
|
assertEquals(exp, res);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOrthogonalDistribution1() {
|
public void testOrthogonalDistribution1() {
|
||||||
val dist = new OrthogonalDistribution(1.0);
|
val dist = new OrthogonalDistribution(1.0);
|
||||||
|
|
||||||
val array = dist.sample(new int[] {6, 9});
|
val array = dist.sample(new int[] {6, 9});
|
||||||
|
|
||||||
// log.info("Array: {}", array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOrthogonalDistribution2() {
|
public void testOrthogonalDistribution2() {
|
||||||
val dist = new OrthogonalDistribution(1.0);
|
val dist = new OrthogonalDistribution(1.0);
|
||||||
|
|
||||||
val array = dist.sample(new int[] {9, 6});
|
val array = dist.sample(new int[] {9, 6});
|
||||||
|
|
||||||
// log.info("Array: {}", array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOrthogonalDistribution3() {
|
public void testOrthogonalDistribution3() {
|
||||||
val dist = new OrthogonalDistribution(1.0);
|
val dist = new OrthogonalDistribution(1.0);
|
||||||
|
|
||||||
val array = dist.sample(new int[] {9, 9});
|
val array = dist.sample(new int[] {9, 9});
|
||||||
|
|
||||||
// log.info("Array: {}", array);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void reproducabilityTest(){
|
public void reproducabilityTest(){
|
||||||
|
|
||||||
int numBatches = 1;
|
int numBatches = 1;
|
||||||
|
|
||||||
for( int t=0; t<10; t++ ) {
|
for (int t = 0; t < 10; t++) {
|
||||||
// System.out.println(t);
|
|
||||||
numBatches = t;
|
numBatches = t;
|
||||||
|
|
||||||
List<INDArray> initial = getList(numBatches);
|
List<INDArray> initial = getList(numBatches);
|
||||||
|
@ -1410,7 +1347,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
List<INDArray> list = getList(numBatches);
|
List<INDArray> list = getList(numBatches);
|
||||||
assertEquals(initial, list);
|
assertEquals(initial, list);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1428,7 +1364,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray arr = Nd4j.create(DataType.DOUBLE, 100);
|
INDArray arr = Nd4j.create(DataType.DOUBLE, 100);
|
||||||
Nd4j.exec(new BernoulliDistribution(arr, 0.5));
|
Nd4j.exec(new BernoulliDistribution(arr, 0.5));
|
||||||
// System.out.println(arr);
|
|
||||||
double sum = arr.sumNumber().doubleValue();
|
double sum = arr.sumNumber().doubleValue();
|
||||||
assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0);
|
assertTrue(String.valueOf(sum), sum > 0.0 && sum < 100.0);
|
||||||
}
|
}
|
||||||
|
@ -1436,7 +1371,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
private List<INDArray> getList(int numBatches){
|
private List<INDArray> getList(int numBatches){
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
List<INDArray> out = new ArrayList<>();
|
List<INDArray> out = new ArrayList<>();
|
||||||
// int numBatches = 32; //passes with 1 or 2
|
|
||||||
int channels = 3;
|
int channels = 3;
|
||||||
int imageHeight = 64;
|
int imageHeight = 64;
|
||||||
int imageWidth = 64;
|
int imageWidth = 64;
|
||||||
|
@ -1446,7 +1380,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRngRepeatabilityUniform(){
|
public void testRngRepeatabilityUniform(){
|
||||||
val nexp = Nd4j.create(DataType.FLOAT, 10);
|
val nexp = Nd4j.create(DataType.FLOAT, 10);
|
||||||
|
@ -1521,7 +1454,6 @@ public class RandomTests extends BaseNd4jTest {
|
||||||
assertEquals(res[0], res1[0]);
|
assertEquals(res[0], res1[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testRandom() {
|
public void testRandom() {
|
||||||
val r1 = new java.util.Random(119);
|
val r1 = new java.util.Random(119);
|
||||||
|
|
|
@ -16,12 +16,16 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.rng;
|
package org.nd4j.linalg.rng;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.OpValidationSuite;
|
import org.nd4j.OpValidationSuite;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -33,19 +37,27 @@ import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
|
||||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.Choice;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.DropOut;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.Linspace;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
|
||||||
|
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
import org.nd4j.common.util.ArrayUtil;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.LinkedHashMap;
|
import java.util.LinkedHashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
import static org.junit.Assert.fail;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class RngValidationTests extends BaseNd4jTest {
|
public class RngValidationTests extends BaseNd4jTest {
|
||||||
|
|
||||||
|
@ -407,8 +419,6 @@ public class RngValidationTests extends BaseNd4jTest {
|
||||||
double alpha = alphaDropoutA(tc.prop("p"));
|
double alpha = alphaDropoutA(tc.prop("p"));
|
||||||
double beta = alphaDropoutB(tc.prop("p"));
|
double beta = alphaDropoutB(tc.prop("p"));
|
||||||
return new AlphaDropOut(Nd4j.ones(tc.getDataType(), tc.shape), tc.arr(), tc.prop("p"), alpha, ALPHA_PRIME, beta);
|
return new AlphaDropOut(Nd4j.ones(tc.getDataType(), tc.shape), tc.arr(), tc.prop("p"), alpha, ALPHA_PRIME, beta);
|
||||||
|
|
||||||
|
|
||||||
case "distributionuniform":
|
case "distributionuniform":
|
||||||
INDArray shape = tc.getShape().length == 0 ? Nd4j.empty(DataType.LONG) : Nd4j.create(ArrayUtil.toDouble(tc.shape)).castTo(DataType.LONG);
|
INDArray shape = tc.getShape().length == 0 ? Nd4j.empty(DataType.LONG) : Nd4j.create(ArrayUtil.toDouble(tc.shape)).castTo(DataType.LONG);
|
||||||
return new DistributionUniform(shape, tc.arr(), tc.prop("min"), tc.prop("max"));
|
return new DistributionUniform(shape, tc.arr(), tc.prop("min"), tc.prop("max"));
|
||||||
|
@ -437,7 +447,6 @@ public class RngValidationTests extends BaseNd4jTest {
|
||||||
return Math.abs(x-y) / (Math.abs(x) + Math.abs(y));
|
return Math.abs(x-y) / (Math.abs(x) + Math.abs(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static final double DEFAULT_ALPHA = 1.6732632423543772;
|
public static final double DEFAULT_ALPHA = 1.6732632423543772;
|
||||||
public static final double DEFAULT_LAMBDA = 1.0507009873554804;
|
public static final double DEFAULT_LAMBDA = 1.0507009873554804;
|
||||||
public static final double ALPHA_PRIME = -DEFAULT_LAMBDA * DEFAULT_ALPHA;
|
public static final double ALPHA_PRIME = -DEFAULT_LAMBDA * DEFAULT_ALPHA;
|
||||||
|
|
|
@ -29,6 +29,12 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
|
|
|
@ -26,6 +26,11 @@ import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.DebugMode;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
||||||
|
|
|
@ -16,6 +16,11 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.workspace;
|
package org.nd4j.linalg.workspace;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertFalse;
|
||||||
|
import static org.junit.Assert.assertNotEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.junit.After;
|
import org.junit.After;
|
||||||
|
@ -25,19 +30,22 @@ import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.api.memory.abstracts.Nd4jWorkspace;
|
|
||||||
|
|
||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
|
@ -60,9 +68,15 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableTimeSeries1() {
|
public void testVariableTimeSeries1() {
|
||||||
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
|
WorkspaceConfiguration configuration = WorkspaceConfiguration
|
||||||
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.EXTERNAL)
|
.builder()
|
||||||
.policyLearning(LearningPolicy.FIRST_LOOP).policyReset(ResetPolicy.ENDOFBUFFER_REACHED).build();
|
.initialSize(0)
|
||||||
|
.overallocationLimit(3.0)
|
||||||
|
.policyAllocation(AllocationPolicy.OVERALLOCATE)
|
||||||
|
.policySpill(SpillPolicy.EXTERNAL)
|
||||||
|
.policyLearning(LearningPolicy.FIRST_LOOP)
|
||||||
|
.policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
|
||||||
|
.build();
|
||||||
|
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(500);
|
||||||
|
@ -70,7 +84,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
|
Nd4jWorkspace workspace = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread("WS1");
|
||||||
// workspace.enableDebug(true);
|
|
||||||
|
|
||||||
assertEquals(0, workspace.getStepNumber());
|
assertEquals(0, workspace.getStepNumber());
|
||||||
|
|
||||||
|
@ -125,7 +138,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
log.info("Workspace state after first block: ---------------------------------------------------------");
|
log.info("Workspace state after first block: ---------------------------------------------------------");
|
||||||
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
|
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
|
||||||
|
|
||||||
|
|
||||||
log.info("--------------------------------------------------------------------------------------------");
|
log.info("--------------------------------------------------------------------------------------------");
|
||||||
|
|
||||||
// we just do huge loop now, with pinned stuff in it
|
// we just do huge loop now, with pinned stuff in it
|
||||||
|
@ -144,7 +156,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
assertNotEquals(0, workspace.getNumberOfPinnedAllocations());
|
assertNotEquals(0, workspace.getNumberOfPinnedAllocations());
|
||||||
assertEquals(0, workspace.getNumberOfExternalAllocations());
|
assertEquals(0, workspace.getNumberOfExternalAllocations());
|
||||||
|
|
||||||
|
|
||||||
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
|
// and we do another clean loo, without pinned stuff in it, to ensure all pinned allocates are gone
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
|
@ -158,12 +169,10 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
assertEquals(0, workspace.getNumberOfPinnedAllocations());
|
assertEquals(0, workspace.getNumberOfPinnedAllocations());
|
||||||
assertEquals(0, workspace.getNumberOfExternalAllocations());
|
assertEquals(0, workspace.getNumberOfExternalAllocations());
|
||||||
|
|
||||||
|
|
||||||
log.info("Workspace state after second block: ---------------------------------------------------------");
|
log.info("Workspace state after second block: ---------------------------------------------------------");
|
||||||
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
|
Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVariableTimeSeries2() {
|
public void testVariableTimeSeries2() {
|
||||||
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
|
WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().initialSize(0).overallocationLimit(3.0)
|
||||||
|
@ -179,8 +188,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
Nd4j.create(500);
|
Nd4j.create(500);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
assertEquals(0, workspace.getStepNumber());
|
assertEquals(0, workspace.getStepNumber());
|
||||||
|
|
||||||
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
|
long requiredMemory = 1000 * Nd4j.sizeOfDataType();
|
||||||
|
@ -189,7 +196,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
assertEquals(shiftedSize, workspace.getInitialBlockSize());
|
||||||
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
|
assertEquals(workspace.getInitialBlockSize() * 4, workspace.getCurrentSize());
|
||||||
|
|
||||||
|
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(configuration, "WS1")) {
|
||||||
Nd4j.create(500);
|
Nd4j.create(500);
|
||||||
|
@ -206,7 +212,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
|
|
||||||
assertEquals(0, workspace.getSpilledSize());
|
assertEquals(0, workspace.getSpilledSize());
|
||||||
assertEquals(0, workspace.getPinnedSize());
|
assertEquals(0, workspace.getPinnedSize());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -238,7 +243,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
assertEquals(exp, result);
|
assertEquals(exp, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testAlignment_1() {
|
public void testAlignment_1() {
|
||||||
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
WorkspaceConfiguration initialConfig = WorkspaceConfiguration.builder().initialSize(10 * 1024L * 1024L)
|
||||||
|
@ -260,7 +264,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNoOpExecution_1() {
|
public void testNoOpExecution_1() {
|
||||||
val configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0)
|
val configuration = WorkspaceConfiguration.builder().initialSize(10000000).overallocationLimit(3.0)
|
||||||
|
@ -424,7 +427,6 @@ public class SpecialWorkspaceTests extends BaseNd4jTest {
|
||||||
Files.delete(tmpFile);
|
Files.delete(tmpFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMigrateToWorkspace(){
|
public void testMigrateToWorkspace(){
|
||||||
val src = Nd4j.createFromArray (1L,2L);
|
val src = Nd4j.createFromArray (1L,2L);
|
||||||
|
|
|
@ -27,6 +27,11 @@ import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
|
||||||
|
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
Loading…
Reference in New Issue