Improving SameDiff tests coverage (#227)

* Gradients tests added

* Fix for Standard deviation serialization + test

Signed-off-by: Alex Black <blacka101@gmail.com>

* More fixes

Signed-off-by: Alex Black <blacka101@gmail.com>

* Test fixed

* Spark config driver host config for CI

Signed-off-by: Alex Black <blacka101@gmail.com>

* Op validation timeout increase

Signed-off-by: Alex Black <blacka101@gmail.com>

* Gradient check - fix for low probability test failure due to randomly all 0s mask

Signed-off-by: AlexDBlack <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Alexander Stoyakin 2020-02-13 01:29:08 +02:00 committed by GitHub
parent 5c9e0bc2bb
commit 8c0e378ec3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
58 changed files with 1027 additions and 255 deletions

View File

@ -414,7 +414,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
INDArray l = TestUtils.randomOneHot(mb, 3);
INDArray lm = TestUtils.randomBernoulli(mb, 1);
assertTrue(lm.sumNumber().intValue() > 0);
int attempts = 0;
while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
lm = TestUtils.randomBernoulli(mb, 1);
}
assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
.labels(l).labelMask(lm));
@ -467,7 +471,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
INDArray l = TestUtils.randomOneHot(mb, 3);
INDArray lm = TestUtils.randomBernoulli(mb, 1);
assertTrue(lm.sumNumber().intValue() > 0);
int attempts = 0;
while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
lm = TestUtils.randomBernoulli(mb, 1);
}
assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f})
.labels(new INDArray[]{l}).labelMask(new INDArray[]{lm}));

View File

@ -67,7 +67,9 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
}
}
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
.set("spark.driver.host", "localhost")
.setAppName("SeqVecTests");
sc = new JavaSparkContext(sparkConf);
}

View File

@ -61,7 +61,9 @@ public class SparkWord2VecTest extends BaseDL4JTest {
sentences.add("one another sentence");
}
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
.set("spark.driver.host", "localhost")
.setAppName("SeqVecTests");
sc = new JavaSparkContext(sparkConf);
}

View File

@ -56,7 +56,9 @@ public class Word2VecTest {
@Test
public void testConcepts() throws Exception {
// These are all default values for word2vec
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
.set("spark.driver.host", "localhost")
.setAppName("sparktest");
// Set SparkContext
JavaSparkContext sc = new JavaSparkContext(sparkConf);
@ -156,6 +158,7 @@ public class Word2VecTest {
@Test
public void testSparkW2VonBiggerCorpus() throws Exception {
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest")
.set("spark.driver.host", "localhost")
.set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g")
.set("spark.executor.memory", "8g");

View File

@ -63,7 +63,7 @@ public class TextPipelineTest extends BaseSparkTest {
@Before
public void before() throws Exception {
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest");
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");
// All the avaliable options. These are default values
word2vec = new Word2Vec.Builder().minWordFrequency(1).setNGrams(1)

View File

@ -85,7 +85,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
if (sc != null)
return sc;
// set to test mode
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
.set("spark.driver.host", "localhost").setAppName("sparktest");
sc = new JavaSparkContext(sparkConf);

View File

@ -59,7 +59,9 @@ public class BaseSparkKryoTest extends BaseSparkTest {
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
.setAppName("sparktest")
.set("spark.driver.host", "localhost");
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");

View File

@ -89,7 +89,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
if (sc != null)
return sc;
// set to test mode
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
.set("spark.driver.host", "localhost").setAppName("sparktest");
sc = new JavaSparkContext(sparkConf);

View File

@ -72,8 +72,9 @@ public class TestKryoWarning {
@Ignore
public void testKryoMessageMLNIncorrectConfig() {
//Should print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
"org.apache.spark.serializer.KryoSerializer");
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
.set("spark.driver.host", "localhost")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
doTestMLN(sparkConf);
}
@ -83,6 +84,7 @@ public class TestKryoWarning {
public void testKryoMessageMLNCorrectConfigKryo() {
//Should NOT print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
.set("spark.driver.host", "localhost")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
@ -93,7 +95,9 @@ public class TestKryoWarning {
@Ignore
public void testKryoMessageMLNCorrectConfigNoKryo() {
//Should NOT print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[*]")
.set("spark.driver.host", "localhost")
.setAppName("sparktest");
doTestMLN(sparkConf);
}
@ -104,8 +108,9 @@ public class TestKryoWarning {
@Ignore
public void testKryoMessageCGIncorrectConfig() {
//Should print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
"org.apache.spark.serializer.KryoSerializer");
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
.set("spark.driver.host", "localhost")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
doTestCG(sparkConf);
}
@ -115,6 +120,7 @@ public class TestKryoWarning {
public void testKryoMessageCGCorrectConfigKryo() {
//Should NOT print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
.set("spark.driver.host", "localhost")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
@ -125,7 +131,9 @@ public class TestKryoWarning {
@Ignore
public void testKryoMessageCGCorrectConfigNoKryo() {
//Should NOT print warning message
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
SparkConf sparkConf = new SparkConf().setMaster("local[*]")
.set("spark.driver.host", "localhost")
.setAppName("sparktest");
doTestCG(sparkConf);
}

View File

@ -138,6 +138,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
SparkConf sparkConf = new SparkConf();
sparkConf.setMaster("local[" + nWorkers + "]");
sparkConf.setAppName("Test");
sparkConf.set("spark.driver.host", "localhost");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
return sc;

View File

@ -58,7 +58,7 @@ public class ExportSupportTest {
}
private void assertSupported(SparkConf conf) throws IOException {
JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test"));
JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test").set("spark.driver.host", "localhost"));
try {
assertTrue(ExportSupport.exportSupported(sc));
} finally {

View File

@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.spark.BaseSparkTest;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
@ -50,17 +51,13 @@ import static org.junit.Assert.*;
/**
* Created by Alex on 17/06/2016.
*/
public class TestTrainingStatsCollection {
public class TestTrainingStatsCollection extends BaseSparkTest {
@Test
public void testStatsCollection() throws Exception {
int nWorkers = 4;
int nWorkers = numExecutors();
SparkConf sparkConf = new SparkConf();
sparkConf.setMaster("local[" + nWorkers + "]");
sparkConf.setAppName("Test");
JavaSparkContext sc = new JavaSparkContext(sparkConf);
JavaSparkContext sc = getContext();
try {

View File

@ -294,6 +294,11 @@ public class DifferentialFunctionFactory {
return new ZerosLike(name, sameDiff(), input).outputVariable();
}
public SDVariable zerosLike(String name, SDVariable input, DataType dataType) {
validateDifferentialFunctionsameDiff(input);
return new ZerosLike(name, sameDiff(), input, dataType).outputVariable();
}
public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
return create(name, shape, 'c', initialize, dataType);
}
@ -1751,12 +1756,12 @@ public class DifferentialFunctionFactory {
return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables();
}
public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable();
public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) {
return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable();
}
public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) {
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables();
}
public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
@ -2638,7 +2643,7 @@ public class DifferentialFunctionFactory {
return new Polygamma(sameDiff, n,x).outputVariable();
}
public SDVariable roll(SDVariable input, SDVariable shift) {
public SDVariable roll(SDVariable input, int shift) {
return new Roll(sameDiff, input, shift).outputVariable();
}

View File

@ -787,9 +787,10 @@ public abstract class SDBaseOps {
* @param number Number of values to generate
* @return SDVariable with linearly spaced elements
*/
public SDVariable linspace(DataType dataType, double start, double stop, long number) {
// TODO: fix or remove, currently it is internal recursion
/*public SDVariable linspace(DataType dataType, double start, double stop, long number) {
return linspace(dataType, start, stop, number);
}
}*/
/**
* Create a new 1d array with values evenly spaced between values 'start' and 'stop'
@ -3093,6 +3094,9 @@ public abstract class SDBaseOps {
return zerosLike(null, input);
}
public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) {
return zerosLike(null, input, dataType);
}
/**
* Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
* if the input shape changes in later execution, the returned variable's shape will also be updated
@ -3106,6 +3110,10 @@ public abstract class SDBaseOps {
return updateVariableNameAndReference(ret, name);
}
public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) {
SDVariable ret = f().zerosLike(name, input, dataType);
return updateVariableNameAndReference(ret, name);
}
/**
* See {@link #any(String, SDVariable, int...)}

View File

@ -2545,7 +2545,7 @@ public class SDMath extends SDOps {
* @param shift number of places to shift elements
* @return array
*/
public SDVariable roll(String name, SDVariable input, SDVariable shift) {
public SDVariable roll(String name, SDVariable input, int shift) {
SDVariable res = f().roll(input,shift);
return updateVariableNameAndReference(res, name);
}

View File

@ -815,8 +815,9 @@ public class FlatBuffersMapper {
}
int[] dims;
if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL
|| node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) {
Type t = node.opType();
if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL
|| t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Type.VARIANCE || t == Type.SUMMARYSTATS) {
dims = node.getDimensions();
if (dims == null)
dims = new int[0];

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.validation;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
@ -38,10 +39,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.loss.bp.*;
@ -1011,7 +1008,10 @@ public class OpValidation {
SpTreeCell.class,
CbowRound.class,
SkipGramRound.class,
HashCode.class
HashCode.class,
HashCode.class,
BitCast.class,
ToggleBits.class
);
return new HashSet<>(list);

View File

@ -200,7 +200,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul.class,
org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp.class,
org.nd4j.linalg.api.ops.impl.reduce.floating.AMean.class,
org.nd4j.linalg.api.ops.impl.reduce.floating.Bias.class,
org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy.class,
org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy.class,
org.nd4j.linalg.api.ops.impl.reduce.floating.Mean.class,

View File

@ -16,21 +16,28 @@
package org.nd4j.linalg.api.ops.custom;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.List;
/**
* This op takes arbitrary number of arrays as input, and returns single "flattened" vector
*
* @author raver119@gmail.com
*/
@Data
@NoArgsConstructor
public class Flatten extends DynamicCustomOp {
private char order;
public Flatten() {
//
}
private int order;
public Flatten(char order, INDArray... inputs) {
this.order = order;
@ -47,10 +54,21 @@ public class Flatten extends DynamicCustomOp {
outputArguments.add(output);
}
public Flatten(SameDiff sameDiff, char order, SDVariable... inputs) {
super(sameDiff, inputs);
this.order = order;
addIArgument(order);
}
@Override
public String opName() {
return "flatten";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Arrays.asList(inputDataTypes.get(0));
}
}

View File

@ -51,6 +51,14 @@ public class FusedBatchNorm extends DynamicCustomOp {
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
@NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
this.outputDataType = x.dataType();
}
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
int dataFormat, int isTraining) {
super("", sameDiff, new SDVariable[]{x, scale, offset});
addIArgument(dataFormat, isTraining);
this.outputDataType = x.dataType();
}
@Override
@ -78,6 +86,8 @@ public class FusedBatchNorm extends DynamicCustomOp {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
outputDataType == null ? DataType.FLOAT : outputDataType,
outputDataType == null ? DataType.FLOAT : outputDataType);
}
}

View File

@ -64,6 +64,6 @@ public class Lu extends DynamicCustomOp {
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Arrays.asList(inputDataTypes.get(0), indexDataType);
return Arrays.asList(inputDataTypes.get(0), indexDataType == null ? DataType.INT32 : indexDataType);
}
}

View File

@ -46,6 +46,11 @@ public class MatrixBandPart extends DynamicCustomOp {
super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
}
public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int minLower, int maxUpper) {
super("", sameDiff, new SDVariable[]{input});
addIArgument(minLower, maxUpper);
}
@Override
public String opName() {
return "matrix_band_part";

View File

@ -45,6 +45,15 @@ public class Roll extends DynamicCustomOp {
super("", sameDiff, new SDVariable[]{input,shift});
}
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
super("", sameDiff, new SDVariable[]{input,axes,shift});
}
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
super("", sameDiff, new SDVariable[]{input});
addIArgument(shift);
}
@Override
public String opName() {
return "roll";

View File

@ -7,9 +7,13 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class TriangularSolve extends DynamicCustomOp {
@ -24,11 +28,27 @@ public class TriangularSolve extends DynamicCustomOp {
super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint});
}
public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs,
boolean lower, boolean adjoint) {
super(sameDiff, new SDVariable[] {matrix, rhs});
addBArgument(lower, adjoint);
}
@Override
public String opName() {
return "triangular_solve";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("adjoint")){
addBArgument(attributesForNode.get("adjoint").getB());
}
if(attributesForNode.containsKey("lower")){
addBArgument(attributesForNode.get("lower").getB());
}
}
@Override
public String tensorflowName() {
return "MatrixTriangularSolve";

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.broadcast;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -27,6 +28,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.List;
@NoArgsConstructor
public class BiasAddGrad extends DynamicCustomOp {
protected boolean nchw = true;
@ -40,7 +42,16 @@ public class BiasAddGrad extends DynamicCustomOp {
super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
}
public BiasAddGrad() {}
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient,
boolean nchw) {
addInputArgument(input, bias, gradient);
this.nchw = nchw;
addBArgument(nchw);
}
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient) {
this(input, bias, gradient, false);
}
@Override
public int opNum() {

View File

@ -16,6 +16,8 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
@ -35,6 +37,8 @@ import java.util.List;
*
* @author raver119@gmail.com
*/
@Data
@NoArgsConstructor
public class FirstIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;
@ -50,9 +54,6 @@ public class FirstIndex extends BaseIndexAccumulation {
this.extraArgs = new Object[] {compare, eps, (double) mode};
}
public FirstIndex() {}
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
this(x, condition, false, dimension);
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -30,6 +31,7 @@ import java.util.List;
*
* @author Adam Gibson
*/
@Data
public class IAMax extends BaseIndexAccumulation {
public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -30,6 +31,7 @@ import java.util.List;
*
* @author Adam Gibson
*/
@Data
public class IAMin extends BaseIndexAccumulation {
public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -31,6 +32,7 @@ import java.util.List;
*
* @author Alex Black
*/
@Data
public class IMax extends BaseIndexAccumulation {
public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
@ -30,6 +31,7 @@ import java.util.List;
*
* @author Alex Black
*/
@Data
public class IMin extends BaseIndexAccumulation {
public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -36,6 +37,7 @@ import java.util.Map;
*
* @author raver119@gmail.com
*/
@Data
public class LastIndex extends BaseIndexAccumulation {
protected Condition condition;
protected double compare;

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
@ -29,6 +30,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
@Data
public class ArgMax extends DynamicCustomOp {
protected DataType outputType;

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
@ -34,6 +35,7 @@ import java.util.Map;
*
* @author Alex Black
*/
@Data
public class ArgMin extends DynamicCustomOp {
protected DataType outputType = DataType.LONG;

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.Collections;
@ -38,8 +39,14 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
protected int classesDim;
public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
// public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
// super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
// this.classesDim = classesDim;
// addIArgument(classesDim);
// }
public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
this.classesDim = classesDim;
addIArgument(classesDim);
}
@ -66,7 +73,8 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
public List<SDVariable> doDiff(List<SDVariable> grad){
//No external gradient
//Args: logits, weigths, label
SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim);
SDVariable[] args = args();
SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim);
return Arrays.asList(grads);
}
}

View File

@ -20,14 +20,18 @@ import lombok.NoArgsConstructor;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.loss.BaseLoss;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@ -56,4 +60,12 @@ public class SoftmaxCrossEntropyLossBp extends BaseLossBp {
public String opName() {
return "softmax_cross_entropy_loss_grad";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(1), inputDataTypes.get(2)); //Same as predictions
}
}

View File

@ -19,8 +19,10 @@ package org.nd4j.linalg.api.ops.impl.loss.bp;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Arrays;
import java.util.List;
@ -34,8 +36,8 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
protected int classesDim;
public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
this.classesDim = classesDim;
addIArgument(classesDim);
}
@ -49,4 +51,9 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
public List<SDVariable> doDiff(List<SDVariable> grad){
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
return Arrays.asList(arg(0).dataType(), arg(1).dataType());
}
}

View File

@ -16,9 +16,12 @@
package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.*;
@ -30,12 +33,9 @@ import java.util.*;
*
* @author Alex Black
*/
@NoArgsConstructor
public class SufficientStatistics extends DynamicCustomOp {
public SufficientStatistics() {
}
public SufficientStatistics(SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable axis, SDVariable shift) {
super(null, sameDiff, argsNoNull(x, axis, shift), false);
}
@ -48,14 +48,30 @@ public class SufficientStatistics extends DynamicCustomOp {
}
}
public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes, INDArray shift) {
if (shift != null)
addInputArgument(x, axes, shift);
else
addInputArgument(x, axes);
}
public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes) {
this(x,axes,null);
}
@Override
public String opName() {
return "sufficient_statistics";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad) {
throw new UnsupportedOperationException("Backprop not yet implemented for op: " + getClass().getSimpleName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
// FIXME
return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0),inputDataTypes.get(0));
}
}

View File

@ -111,7 +111,7 @@ public class TensorMmul extends DynamicCustomOp {
int[][] deletedAxes = new int[][]{
removeIndex(aAxes, sumAxes[0]),
removeIndex(bAxes, sumAxes[1])};
int[] gAxes = range(0, i_v1.get(0).getShape().length);
int[] gAxes = range(0, i_v1.get(0).eval().shape().length);
int[][] firstAxes = new int[][]{
Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length),
deletedAxes[1]
@ -144,18 +144,20 @@ public class TensorMmul extends DynamicCustomOp {
int[][] axes) {
int validationLength = Math.min(axes[0].length, axes[1].length);
INDArray aArray = a.eval();
INDArray bArray = b.eval();
for (int i = 0; i < validationLength; i++) {
if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]])
if (aArray.shape()[axes[0][i]] != bArray.shape()[axes[1][i]])
throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
if (axes[0][i] < 0)
axes[0][i] += a.getShape().length;
axes[0][i] += aArray.shape().length;
if (axes[1][i] < 0)
axes[1][i] += b.getShape().length;
axes[1][i] += bArray.shape().length;
}
List<Integer> listA = new ArrayList<>();
for (int i = 0; i < a.getShape().length; i++) {
for (int i = 0; i < aArray.shape().length; i++) {
if (!Ints.contains(axes[0], i))
listA.add(i);
}
@ -164,7 +166,7 @@ public class TensorMmul extends DynamicCustomOp {
List<Integer> listB = new ArrayList<>();
for (int i = 0; i < b.getShape().length; i++) {
for (int i = 0; i < bArray.shape().length; i++) {
if (!Ints.contains(axes[1], i))
listB.add(i);
}
@ -172,9 +174,9 @@ public class TensorMmul extends DynamicCustomOp {
int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB));
int n2 = 1;
int aLength = Math.min(a.getShape().length, axes[0].length);
int aLength = Math.min(aArray.shape().length, axes[0].length);
for (int i = 0; i < aLength; i++) {
n2 *= a.getShape()[axes[0][i]];
n2 *= aArray.shape()[axes[0][i]];
}
//if listA and listB are empty these do not initialize.
@ -186,13 +188,13 @@ public class TensorMmul extends DynamicCustomOp {
} else {
oldShapeA = Longs.toArray(listA);
for (int i = 0; i < oldShapeA.length; i++)
oldShapeA[i] = a.getShape()[(int) oldShapeA[i]];
oldShapeA[i] = aArray.shape()[(int) oldShapeA[i]];
}
int n3 = 1;
int bNax = Math.min(b.getShape().length, axes[1].length);
int bNax = Math.min(bArray.shape().length, axes[1].length);
for (int i = 0; i < bNax; i++) {
n3 *= b.getShape()[axes[1][i]];
n3 *= bArray.shape()[axes[1][i]];
}
@ -203,7 +205,7 @@ public class TensorMmul extends DynamicCustomOp {
} else {
oldShapeB = Longs.toArray(listB);
for (int i = 0; i < oldShapeB.length; i++)
oldShapeB[i] = b.getShape()[(int) oldShapeB[i]];
oldShapeB[i] = bArray.shape()[(int) oldShapeB[i]];
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.reduce.bp;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -30,7 +31,7 @@ import java.util.List;
/**
* @author Alex Black
*/
@NoArgsConstructor
public abstract class BaseReductionBp extends DynamicCustomOp {
protected boolean keepDims;
@ -96,7 +97,12 @@ public abstract class BaseReductionBp extends DynamicCustomOp {
addArgs();
}
public BaseReductionBp(){}
public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions){
super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2});
this.keepDims = keepDims;
this.dimensions = dimensions;
addArgs();
}
protected void addArgs(){
addTArgument(keepDims ? 1 : 0);

View File

@ -16,17 +16,23 @@
package org.nd4j.linalg.api.ops.impl.reduce.bp;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.List;
/**
* Backprop op for Dot pairwise reduction operation
*
* @author Alex Black
*/
@NoArgsConstructor
public class DotBp extends BaseReductionBp {
public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) {
@ -37,10 +43,22 @@ public class DotBp extends BaseReductionBp {
super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions);
}
public DotBp(){}
public DotBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput,
INDArray outputX, INDArray outputY, boolean keepDims, int... dimensions) {
super(origInput1, origInput2, gradAtOutput, outputX, outputY, keepDims, dimensions);
}
@Override
public String opName() {
return "reduce_dot_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatype for %s, got input %s", getClass(), dataTypes);
Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0));
Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(1));
Preconditions.checkState(dataTypes.get(2).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(2));
return Arrays.asList(dataTypes.get(0), dataTypes.get(0));
}
}

View File

@ -1,86 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.reduce.floating;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Calculate a bias
*
* @author Adam Gibson
*/
public class Bias extends BaseReduceFloatOp {
private double mean;
public Bias(SameDiff sameDiff, SDVariable i_v, int[] dimensions, double mean) {
super(sameDiff, i_v, dimensions);
this.mean = mean;
}
public Bias(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions, double mean) {
super(sameDiff, i_v, i_v2, dimensions);
this.mean = mean;
}
public Bias() {}
public Bias(INDArray x, int... dimensions) {
super(x, dimensions);
}
@Override
public Map<String, Object> propertiesForFunction() {
Map<String,Object> ret = new LinkedHashMap<>();
ret.put("mean",mean);
return ret;
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "bias";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
}

View File

@ -24,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -45,6 +46,7 @@ public class SequenceMask extends DynamicCustomOp {
public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) {
super(null, sameDiff, new SDVariable[] {input, maxLen}, false);
this.dataType = dataType;
addDArgument(dataType);
}
public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) {
@ -53,13 +55,23 @@ public class SequenceMask extends DynamicCustomOp {
this.is_static_maxlen = true;
addIArgument(maxLen);
this.dataType = dataType;
addDArgument(dataType);
}
public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) {
super(null, sameDiff, new SDVariable[] {input}, false);
this.dataType = dataType;
addDArgument(dataType);
}
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
addInputArgument(input);
addIArgument(maxLen);
//addIArgument(dataType.toInt());
addDArgument(dataType);
this.dataType = dataType;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -39,23 +40,37 @@ import java.util.Map;
* @author Adam Gibson
*/
@Slf4j
@NoArgsConstructor
public class ZerosLike extends DynamicCustomOp {
protected DataType outputType; //Allow customizing dtype for TF import
public ZerosLike() {
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
this(name, sameDiff, input, false, input.dataType());
}
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
this(name, sameDiff, input, false);
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
this(name, sameDiff, input, false, dataType);
}
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) {
this(name, sameDiff, input, inPlace, input.dataType());
}
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace, DataType dataType) {
super(name, sameDiff, new SDVariable[]{input}, inPlace);
addDArgument(dataType);
}
public ZerosLike(INDArray in, INDArray out){
this(in, out, in.dataType());
}
public ZerosLike(INDArray in, INDArray out, DataType dataType) {
super(null, in, out, null, null);
if (dataType != null) {
addDArgument(dataType);
}
}

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.Collections;
@ -52,16 +53,13 @@ public class BatchToSpace extends DynamicCustomOp {
}
public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
super(null, sameDiff, args, inPlace);
super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace);
this.blocks = blocks;
this.crops = crops;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < crops.length; e++)
addIArgument(crops[e][0], crops[e][1]);
}
@Override

View File

@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.Collections;
@ -53,16 +54,12 @@ public class SpaceToBatch extends DynamicCustomOp {
}
public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
super(null, sameDiff, args, inPlace);
super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace);
this.blocks = blocks;
this.padding = padding;
for (val b : blocks)
addIArgument(b);
for (int e = 0; e < padding.length; e++)
addIArgument(padding[e][0], padding[e][1]);
addIArgument(blocks[0]);
}
@Override

View File

@ -58,7 +58,8 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
*
* @author Alex Black
*/
@NoArgsConstructor
public class UnsortedSegmentMean extends DynamicCustomOp {
private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
addIArgument(numSegments);
}
public UnsortedSegmentMean(){ }
@Override
public String opName(){
return "unsorted_segment_mean";
@ -60,7 +60,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
*
* @author Alex Black
*/
@NoArgsConstructor
public class UnsortedSegmentMin extends DynamicCustomOp {
private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
addIArgument(numSegments);
}
public UnsortedSegmentMin(){ }
@Override
public String opName(){
return "unsorted_segment_min";
@ -60,7 +60,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
*
* @author Alex Black
*/
@NoArgsConstructor
public class UnsortedSegmentProd extends DynamicCustomOp {
private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
addIArgument(numSegments);
}
public UnsortedSegmentProd(){ }
@Override
public String opName(){
return "unsorted_segment_prod";
@ -60,7 +60,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.ArrayList;
@ -31,18 +33,23 @@ import java.util.List;
*
* @author Alex Black
*/
@NoArgsConstructor
public class UnsortedSegmentSqrtN extends DynamicCustomOp {
private int numSegments;
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
addInputArgument(data, segmentIds);
addIArgument(numSegments);
this.numSegments = numSegments;
}
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
this.numSegments = numSegments;
addIArgument(numSegments);
}
public UnsortedSegmentSqrtN(){ }
@Override
public String opName(){
return "unsorted_segment_sqrt_n";
@ -60,7 +67,8 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
List<DataType> out = new ArrayList<>();
for( int i=0; i<numSegments; i++ ){
out.add(inputDataTypes.get(0));

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.segment;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -32,6 +33,7 @@ import java.util.List;
*
* @author Alex Black
*/
@NoArgsConstructor
public class UnsortedSegmentSum extends DynamicCustomOp {
private int numSegments;
@ -42,8 +44,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
addIArgument(numSegments);
}
public UnsortedSegmentSum(){ }
@Override
public String opName(){
return "unsorted_segment_sum";
@ -61,7 +61,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
//TODO Allow customizing output type
return Collections.singletonList(Nd4j.defaultFloatingPointType());
}

View File

@ -50,6 +50,11 @@ public class LayerOpValidation extends BaseOpValidation {
super(backend);
}
@Override
public long getTimeoutMilliseconds() {
return 90000L;
}
@Test
public void testXwPlusB() {
Nd4j.getRandom().setSeed(12345);
@ -319,7 +324,7 @@ public class LayerOpValidation extends BaseOpValidation {
@Test
public void testIm2Col() {
OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
Nd4j.getRandom().setSeed(12345);
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};

View File

@ -32,6 +32,9 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
@ -513,7 +516,7 @@ public class MiscOpValidation extends BaseOpValidation {
@Test
public void testTrace(){
//TODO need to work out how to handle shape_op for scalars...
OpValidationSuite.ignoreFailing();
//OpValidationSuite.ignoreFailing();
Nd4j.getRandom().setSeed(12345);
for( int[] inShape : new int[][]{{3,3}}){
@ -546,12 +549,15 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable x = sameDiff.var("x", arr);
SDVariable y = sameDiff.var("y", arr2);
SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}});
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.getShape());
assertEquals(32, sameDiff.numElements());
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
result.eval().shape());
assertEquals(16, sameDiff.numElements());
SDVariable loss = sameDiff.standardDeviation(result, true);
sameDiff.addLossVariable(loss);
String err = OpValidation.validate(new TestCase(sameDiff));
assertNull(err);
}
@Test
@ -1782,4 +1788,338 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(exp, out); //Values in x not in y
assertEquals(exp, outIdx); //Indices of the values in x not in y
}
@Test
public void testDivideNoNan() {
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input1 = sameDiff.var(in1);
SDVariable input2 = sameDiff.var(in2);
INDArray expected = Nd4j.ones(3,4);
SDVariable output = new DivideNoNan(sameDiff, input1, input2).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testDigamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
-0.5772157,0.42278433,0.9227843,1.2561177,1.5061177,1.7061176,1.8727844,2.0156415,2.1406415,2.2517526,2.3517525,2.4426618
}).reshape(3,4);
val tc = new OpTestCase(new Digamma(in1)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testFlatten() {
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3,3,3);
SDVariable sdx = sameDiff.var(x);
INDArray expected = Nd4j.linspace(DataType.DOUBLE,1,27,1);
SDVariable output = new Flatten(sameDiff, 'c', sdx).outputVariable();
SDVariable loss = sameDiff.standardDeviation(sdx, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testFusedBatchNorm() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
scale.assign(0.5);
INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
offset.assign(2.0);
SDVariable input1 = sameDiff.var(x);
SDVariable input2 = sameDiff.var(scale);
SDVariable input3 = sameDiff.var(offset);
INDArray expectedY = Nd4j.createFromArray(new double[]{
985.5258, 985.5258, 985.5258, 985.5258,
659.7321, 659.7321, 659.7321, 659.7321,
399.0972, 399.0972, 399.0972, 399.0972,
203.6210, 203.6210, 203.6210, 203.6210,
73.3036, 73.3036, 73.3036, 73.3036,
8.1448, 8.1448, 8.1448, 8.1448,
8.1448, 8.1448, 8.1448, 8.1448,
73.3036, 73.3036, 73.3036, 73.3036,
203.6210, 203.6210, 203.6210, 203.6210,
399.0972, 399.0972, 399.0972, 399.0972,
659.7321, 659.7321, 659.7321, 659.7321,
985.5258, 985.5258, 985.5258, 985.5258}).reshape(x.shape());
INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables();
SDVariable loss = sameDiff.standardDeviation(input1, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(outputs[0].name(), expectedY)
.expectedOutput(outputs[1].name(), expectedBatchMean)
.expectedOutput(outputs[2].name(), expectedBatchVar);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testIgamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.63212055,0.59399414,0.5768099,0.56652874,0.5595013,0.5542634,0.5501591,0.5463888,0.54329145,0.54048204,0.5378594,0.53233755
}).reshape(3,4);
val tc = new OpTestCase(new Igamma(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testIgammaC() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.36787945,0.40600586,0.42319012,0.43347126,0.4404987,0.44573656,0.4498409,0.45361117,0.45670855,0.459518,0.46214062,0.46766248
}).reshape(3,4);
val tc = new OpTestCase(new Igammac(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testLgamma() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4);
SDVariable sdInput = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308
}).reshape(3,4);
SDVariable output = new Lgamma(sameDiff, sdInput).outputVariable();
SDVariable loss = sameDiff.standardDeviation(sdInput, true);
sameDiff.addLossVariable(loss);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testLu() {
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.createFromArray(new double[]{
1., 2., 3., 0., 2., 3., 0., 0., 7.
}).reshape(3,3);
SDVariable input1 = sameDiff.var(in1);
INDArray expected = Nd4j.createFromArray(new double[]{
1., 2., 3., 0., 2., 3., 0., 0., 7
}).reshape(3,3);
INDArray pexpected = Nd4j.createFromArray(new int[]{
0, 1, 2
});
sameDiff.loss.l2Loss(input1);
SDVariable[] output = new Lu(sameDiff, input1).outputVariables();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output[0].name(), expected)
.expectedOutput(output[1].name(), pexpected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testMatrixBandPart() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
0.7271f,0.1804f,0.5056f,0.8925f,
0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
SDVariable sdInput = sameDiff.var(input);
SDVariable sdInput1 = sameDiff.constant(1);
SDVariable sdInput2 = sameDiff.constant(-1);
INDArray expected = Nd4j.createFromArray(new float[]{
0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f,
0.f, 0.9234f, 0.0856f, 0.7938f
}).reshape(3,4);
sameDiff.loss.l2Loss(sdInput);
SDVariable output = new MatrixBandPart(sameDiff, sdInput, 1, -1).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testPolygamma() {
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
INDArray expected = Nd4j.createFromArray(new double[]{
1.644934,-0.4041138,0.1189394,-0.03750069,0.01226151,-0.0041002957,0.001392272,-4.780109E-4,1.6549716E-4,-5.7675967E-5,2.0206635E-5,-7.1101636E-6
}).reshape(3,4);
val tc = new OpTestCase(new Polygamma(in1, in2)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testTriangularSolve() {
INDArray a = Nd4j.createFromArray(new float[]{
3.f, 0.f, 0.f, 0.f,
2.f, 1.f, 0.f, 0.f,
1.f, 0.f, 1.f, 0.f,
1.f, 1.f, 1.f, 1.f
}).reshape(4,4);
INDArray b = Nd4j.createFromArray(new float[]{
4.f, 2.f, 4.f, 2.f
}).reshape(4,1);
INDArray expected = Nd4j.createFromArray(new float[]{
1.333333f, 2.0f, 4.0f, 2.0f
}).reshape(4,1);
val tc = new OpTestCase(new TriangularSolve(a, b, false, true)).expectedOutput(0, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testBiasAdd() {
SameDiff sameDiff = SameDiff.create();
INDArray in1 = Nd4j.linspace(1, 12, 12);
INDArray in2 = Nd4j.linspace(1, 12, 12);
SDVariable input1 = sameDiff.var(in1);
SDVariable input2 = sameDiff.var(in2);
INDArray expected = Nd4j.createFromArray(new double[]{
2.0000, 4.0000, 6.0000, 8.0000, 10.0000, 12.0000, 14.0000, 16.0000, 18.0000, 20.0000, 22.0000, 24.0000
});
SDVariable output = new BiasAdd(sameDiff, input1, input2, false).outputVariable();
SDVariable loss = sameDiff.standardDeviation(input1, true);
sameDiff.addLossVariable(loss);
SDVariable loss2 = sameDiff.standardDeviation(input2, true);
sameDiff.addLossVariable(loss2);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testBiasAddGrad() {
SameDiff sameDiff = SameDiff.create();
INDArray x = Nd4j.linspace(DataType.FLOAT,1, 24, 24).reshape(2,2,2,3);
INDArray grad = Nd4j.linspace(DataType.FLOAT, 0.1, 0.1, 24).reshape(2,2,2,3);
INDArray bias = Nd4j.createFromArray(new float[]{-1.f, -2.f, -3.f});
INDArray expected = Nd4j.createFromArray(new float[]{9.2f, 10.f , 10.8f});
OpTestCase tc = new OpTestCase(new BiasAddGrad(x, bias, grad,false)).
expectedOutput(0, grad).
expectedOutput(1, expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testRoll() {
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
reshape(2,2,4,2);
INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
21.41, 21.42, 22.11, 22.12
}).reshape(x.shape());
int shift = 6;
val tc = new OpTestCase(new Roll(x,shift)).expectedOutput(0,expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}

View File

@ -35,16 +35,20 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
@ -96,7 +100,7 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testZeroCount() {
List<String> allFailed = new ArrayList<>();
for (int i = 0; i < 2; i++) {
for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create();
INDArray ia;
@ -159,25 +163,25 @@ public class ReductionOpValidation extends BaseOpValidation {
@Test
public void testReductionGradientsSimple() {
OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
//Test reductions: final and only function
Nd4j.getRandom().setSeed(12345);
List<String> failed = new ArrayList<>();
for (int i = 0; i < 21; i++) {
SameDiff sd = SameDiff.create();
int nOut = 4;
int minibatch = 10;
SDVariable input = sd.var("in", -1, nOut);
SDVariable input = sd.var("in", minibatch, nOut);
INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
long length = nOut * minibatch;
SDVariable loss;
String name;
TestCase tc = new TestCase(sd);
boolean gradCheck = true;
switch (i) {
case 0:
loss = sd.mean("loss", input);
@ -234,11 +238,13 @@ public class ReductionOpValidation extends BaseOpValidation {
loss = sd.math().countNonZero("loss", input);
name = "countNonZero";
tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
gradCheck = false; //Long out, not floating point
break;
case 11:
loss = sd.math().countZero("loss", input);
name = "countZero";
tc.expectedOutput("loss", Nd4j.scalar(0));
tc.expectedOutput("loss", Nd4j.scalar(0L));
gradCheck = false; //Long out, not floating point
break;
case 12:
loss = sd.math().amax("loss", input);
@ -272,7 +278,7 @@ public class ReductionOpValidation extends BaseOpValidation {
loss = sd.math().logSumExp("loss", input);
INDArray expArr = Transforms.exp(inputArr);
double sum = expArr.sumNumber().doubleValue();
tc.expected("loss", Nd4j.create(new double[]{Math.log(sum)}));
tc.expected("loss", Nd4j.scalar(Math.log(sum)));
break;
case 18:
inputArr = Nd4j.rand(minibatch, nOut);
@ -307,9 +313,15 @@ public class ReductionOpValidation extends BaseOpValidation {
log.info("*** Starting test: " + msg);
sd.associateArrayWithVariable(inputArr, input);
if(gradCheck) {
sd.addLossVariable(loss);
}
tc.testName(msg);
if(!gradCheck){
tc.gradientCheck(false);
}
String error = OpValidation.validate(tc, true);
if (error != null)
failed.add(error);
@ -629,14 +641,14 @@ public class ReductionOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>();
for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) {
for (int i = 6; i < 7; i++) {
for (int i = 0; i < 7; i++) {
SameDiff sd = SameDiff.create();
sd.setLogExecution(false);
SDVariable in = sd.var("in", -1, d1, d2);
SDVariable in2 = sd.var("in2", -1, d1, d2);
SDVariable in = sd.var("in", d1, d1, d2);
SDVariable in2 = sd.var("in2", d0, d1, d2);
INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100);
INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);
@ -645,40 +657,43 @@ public class ReductionOpValidation extends BaseOpValidation {
SDVariable reduced;
String name;
TestCase tc = new TestCase(sd);
Double maxRelError = null;
switch (i) {
case 0:
reduced = sd.math().manhattanDistance(in, in2, reduceDims);
name = "manhattan";
exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, false, false, reduceDims));
break;
case 1:
reduced = sd.math().euclideanDistance(in, in2, reduceDims);
name = "euclidean";
exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, false, false, reduceDims));
break;
case 2:
inArr.muli(1e-4);
in2Arr.muli(1e-4);
reduced = sd.math().cosineSimilarity(in, in2, reduceDims);
name = "cosine";
exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, false, false, reduceDims));
maxRelError = 1e-4;
break;
case 3:
reduced = sd.math().cosineDistance(in, in2, reduceDims);
name = "cosinedistance";
exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, false, false, reduceDims));
maxRelError = 1e-4;
break;
case 4:
reduced = sd.math().hammingDistance(in, in2, reduceDims);
name = "hamming";
exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, false, false, reduceDims));
break;
case 5:
name = "jaccard";
reduced = sd.math().jaccardDistance(name, in, in2, reduceDims);
inArr.divi(100).addi(0.1);
in2Arr.divi(100).addi(0.1);
exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, true, false, reduceDims));
exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, false, false, reduceDims));
if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2)
continue;
@ -708,6 +723,9 @@ public class ReductionOpValidation extends BaseOpValidation {
tc.expected(reduced, exp);
if(maxRelError != null)
tc.gradCheckMaxRelativeError(maxRelError);
String error = OpValidation.validate(tc, true);
if (error != null) {
failed.add(msg + " - " + error);
@ -768,7 +786,6 @@ public class ReductionOpValidation extends BaseOpValidation {
}
@Test
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
public void testNormalizeMomentsOp() {
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
INDArray ssSum = data.sum(0);
@ -780,7 +797,7 @@ public class ReductionOpValidation extends BaseOpValidation {
INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape());
INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape());
OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.INT, 10), ssSum, ssSqSum, mean, var));
OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.DOUBLE, 10), ssSum, ssSqSum, mean, var));
op.expectedOutput(0, meanExp);
op.expectedOutput(1, varExp);
@ -821,7 +838,7 @@ public class ReductionOpValidation extends BaseOpValidation {
List<String> failed = new ArrayList<>();
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]);
INDArray in = Nd4j.rand(3, 4);
INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4);
for (int t = 0; t < 4; t++) {
int[] d = dims.get(t);
@ -838,52 +855,47 @@ public class ReductionOpValidation extends BaseOpValidation {
switch (i) {
case 0:
reduce = s.argmax(dim);
exp = Nd4j.argMax(in, dim).castTo(DataType.DOUBLE);
exp = Nd4j.argMax(in, dim);
name = "argmax";
break;
case 1:
reduce = s.argmin(dim);
exp = Nd4j.argMin(in, dim).castTo(DataType.DOUBLE);
exp = Nd4j.argMin(in, dim);
name = "argmin";
break;
case 2:
reduce = sd.math().iamax(s, dim);
exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
exp = exp.castTo(DataType.DOUBLE);
name = "iamax";
break;
case 3:
reduce = sd.math().iamin(s, dim);
exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
exp = exp.castTo(DataType.DOUBLE);
name = "iamin";
break;
case 4:
reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim);
exp = in.sum(dim).assign(0);
exp = exp.castTo(DataType.DOUBLE);
exp = in.sum(dim).assign(0).castTo(DataType.INT64);
name = "firstindex";
break;
case 5:
reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim);
if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2});
else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3});
else exp = Nd4j.scalar(11.0);
exp = exp.castTo(DataType.DOUBLE);
if (t == 0) exp = Nd4j.createFromArray(2L, 2, 2, 2);
else if (t == 1) exp = Nd4j.createFromArray(3L, 3, 3);
else exp = Nd4j.scalar(11L);
name = "lastindex";
break;
case 6:
reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim);
if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3});
else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4});
else exp = Nd4j.scalar(12.0);
exp = exp.castTo(DataType.DOUBLE);
if (t == 0) exp = Nd4j.createFromArray(3L, 3, 3, 3);
else if (t == 1) exp = Nd4j.createFromArray(4L, 4, 4);
else exp = Nd4j.scalar(12L);
name = "matchConditionCount";
break;
default:
throw new RuntimeException();
}
SDVariable preCast = reduce;
reduce = reduce.castTo(DataType.DOUBLE);
SDVariable loss;
@ -894,7 +906,7 @@ public class ReductionOpValidation extends BaseOpValidation {
}
TestCase tc = new TestCase(sd)
.expected(reduce, exp)
.expected(preCast, exp)
.gradientCheck(false)
.testName(name + " - " + (dim == null ? null : Arrays.toString(dim)));
@ -1335,4 +1347,254 @@ public class ReductionOpValidation extends BaseOpValidation {
}
}
}
@Test
public void testSufficientStatisticsOp() {
INDArray data = Nd4j.createFromArray(new double[]{
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
}).reshape(2,2,2,4);
INDArray axes = Nd4j.linspace(DataType.LONG, 0, 3, 1);
OpTestCase op = new OpTestCase(new SufficientStatistics(data, axes));
INDArray expected1 = Nd4j.scalar(8.0);
INDArray expected2 = Nd4j.createFromArray(new double[]{
30.2, 5., 7.8, 22.8
});
INDArray expected3 = Nd4j.createFromArray(new double[]{
154.22, 7., 14.34, 103.62
});
op.expectedOutput(0, expected1);
op.expectedOutput(1, expected2);
op.expectedOutput(2, expected3);
String err = OpValidation.validate(op);
assertNull(err);
}
@Test
public void testStandardDeviation() {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 8, 8).reshape(2, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
2, 2, 2, 2
});
if(keepDims){
expected = expected.reshape(1,4);
}
SDVariable output = new StandardDeviation(sameDiff, input, false, keepDims, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
@Test
public void testSquaredNorm() {
for (boolean keepDims : new boolean[]{false, true}) {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.scalar(30.0000);
if(keepDims)
expected = expected.reshape(1);
SDVariable output = new SquaredNorm(sameDiff, input, keepDims, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}
@Test
public void testShannonEntropy() {
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4).castTo(DataType.DOUBLE);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.scalar(-69.68162);
SDVariable output = new ShannonEntropy(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testEntropy() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 4, 4);
SDVariable input = sameDiff.var(in);
double expected = -10.2273;
SDVariable output = new Entropy(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), Nd4j.scalar(expected));
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testAMean() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
5.0000, 6.0000, 7.0000, 8.0000
});
SDVariable output = new AMean(sameDiff, input, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testMean() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
5.0000, 6.0000, 7.0000, 8.0000
});
SDVariable output = new Mean(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNorm1() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
15.0000, 18.0000, 21.0000, 24.0000
});
SDVariable output = new Norm1(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNorm2() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
10.3441, 11.8322, 13.3791, 14.9666
});
SDVariable output = new Norm2(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testNormMax() {
SameDiff sameDiff = SameDiff.create();
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
SDVariable input = sameDiff.var(in);
INDArray expected = Nd4j.createFromArray(new double[]{
9.0000, 10.0000, 11.0000, 12.0000
});
SDVariable output = new NormMax(sameDiff, input, false, new int[]{0}).outputVariable();
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
@Test
public void testSoftmaxCrossEntropyWithLogitsLoss() {
OpValidationSuite.ignoreFailing();
SameDiff sameDiff = SameDiff.create();
INDArray labels = Nd4j.createFromArray(new double[]{
0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0
}).reshape(2,3,4);
INDArray logits = Nd4j.linspace(DataType.DOUBLE, 0.1, 0.1, 24).reshape(2,3,4);
INDArray expected = Nd4j.createFromArray(new double[]{
0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328
}).reshape(3,4);
SDVariable sdLogits = sameDiff.var("logits", logits);
SDVariable sdLabels = sameDiff.var("labels", labels);
SDVariable loss = sameDiff.math().abs(sdLogits);
SDVariable output = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff, sdLogits, sdLabels, 0).outputVariable();
sameDiff.setLossVariables(output);
TestCase tc = new TestCase(sameDiff)
.gradientCheck(true)
.expectedOutput(output.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
}
}

View File

@ -1016,7 +1016,7 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testConstant(){
OpValidationSuite.ignoreFailing();
//OpValidationSuite.ignoreFailing();
//Case 0: no shape
SameDiff sd = SameDiff.create();
@ -1035,7 +1035,9 @@ public class ShapeOpValidation extends BaseOpValidation {
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
loss = constant.std(true);
assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
assertNull(OpValidation.validate(new TestCase(sd)
.gradientCheck(false)
.expected(constant, Nd4j.create(DataType.FLOAT, 3,4,5))));
}
@ -1272,7 +1274,7 @@ public class ShapeOpValidation extends BaseOpValidation {
SameDiff sd = SameDiff.create();
SDVariable data = sd.var("data", d);
SDVariable segments = sd.var("segments", s);
SDVariable segments = sd.constant("segments", s);
SDVariable sm;
INDArray exp;
@ -1326,6 +1328,7 @@ public class ShapeOpValidation extends BaseOpValidation {
}
SDVariable loss = sm.std(true);
sd.addLossVariable(loss);
TestCase tc = new TestCase(sd)
.testName(op)
@ -1363,17 +1366,19 @@ public class ShapeOpValidation extends BaseOpValidation {
@Test
public void testSequenceMask() {
OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue?
SameDiff sameDiff = SameDiff.create();
INDArray arr = Nd4j.create(new float[] {1, 3, 2}).reshape(3);
SDVariable lengths = sameDiff.var("lengths", arr);
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
// arr is not trainable, so it's constant in model
SDVariable lengths = sameDiff.constant(arr);
// Test with static max len
int maxlen = 5;
INDArray expected = Nd4j.create(new float[] {1, 0, 0, 0, 0,
1, 1, 1, 0, 0,
1, 1, 0, 0, 0},
new long[]{3, 5});
INDArray expected = Nd4j.create(new float[] {
1.f, 0.f, 0.f, 0.f, 0.f,
1.f, 1.f, 1.f, 0.f, 0.f,
1.f, 1.f, 0.f, 0.f, 0.f
}).reshape(3,5);
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
assertArrayEquals(expected.shape(), result1.eval().shape());
assertEquals(expected, result1.eval());
@ -1382,14 +1387,14 @@ public class ShapeOpValidation extends BaseOpValidation {
String err = OpValidation.validate(new TestCase(sameDiff)
.expected(result1, expected)
.gradCheckSkipVariables(lengths.name()));
.gradientCheck(false));
assertNull(err);
// Test with dynamic maxlen
lengths = sameDiff.var("lengths2", arr); // required because of an internal samediff bug
SDVariable maxLen = sameDiff.var("maxLen", Nd4j.create(new float[]{5}).reshape(1));
lengths = sameDiff.constant("lengths2", arr);
SDVariable maxLen = sameDiff.constant("maxLen", Nd4j.scalar(5));
SDVariable result2 = sameDiff.sequenceMask(lengths, maxLen, DataType.FLOAT);
assertArrayEquals(expected.shape(), result2.eval().shape());
// assertArrayEquals(expected.shape(), result2.eval().shape());
assertEquals(expected, result2.eval());
}

View File

@ -303,7 +303,7 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testBatchToSpace() {
OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(1337);
int miniBatch = 4;
@ -314,7 +314,6 @@ public class TransformOpValidation extends BaseOpValidation {
int[] cropShape = new int[]{M, 2};
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT);
SameDiff sd = SameDiff.create();
@ -323,7 +322,8 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1);
DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space")
.addInputs(input, blocks, crops)
.addInputs(input, crops)
.addIntegerArguments(2)
.addOutputs(expOut).build();
Nd4j.getExecutioner().exec(op);
@ -340,7 +340,7 @@ public class TransformOpValidation extends BaseOpValidation {
@Test
public void testSpaceToBatch() {
OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
Nd4j.getRandom().setSeed(7331);
@ -352,7 +352,6 @@ public class TransformOpValidation extends BaseOpValidation {
int[] paddingShape = new int[]{M, 2};
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT);
SameDiff sd = SameDiff.create();
@ -361,7 +360,8 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1);
DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch")
.addInputs(input, blocks, padding)
.addIntegerArguments(2)
.addInputs(input, padding)
.addOutputs(expOut).build();
Nd4j.getExecutioner().exec(op);

View File

@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
@ -1737,4 +1738,19 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, ret[0]);
}
@Test
public void testSequenceMask() {
INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
// Test with static max len
int maxlen = 2;
INDArray expected = Nd4j.createFromArray(new int[]{
1,0,0,
1,1,1,
1,1,0
}).reshape(3, 3);
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));
assertEquals(expected, ret[0]);
}
}

View File

@ -318,8 +318,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
"num2Scalar" should "convert number to Scalar INDArray" in {
assert(1.toScalar.data() == List(1).toNDArray.data())
assert(2f.toScalar.data() == List(2).toNDArray.data())
assert(3d.toScalar.data() == List(3).toNDArray.data())
assert(1.toScalar.reshape(1) == List(1).toNDArray)
assert(2f.toScalar.reshape(1) == List(2f).toNDArray)
assert(3d.toScalar.reshape(1) == List(3d).toNDArray)
}
}