Improving SameDiff tests coverage (#227)

* Gradients tests added
* Fix for Standard deviation serialization + test
* More fixes
* Test fixed
* Spark config driver host config for CI
* Op validation timeout increase
* Gradient check - fix for low probability test failure due to randomly all 0s mask

Signed-off-by: Alex Black <blacka101@gmail.com>
Signed-off-by: AlexDBlack <blacka101@gmail.com>
Co-authored-by: Alex Black <blacka101@gmail.com>

parent 5c9e0bc2bb
commit 8c0e378ec3
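For the last item, the failure mode is easy to quantify: a Bernoulli(p) mask with n entries is all zeros with probability (1-p)^n, which is rare but non-zero, so the gradient check failed occasionally on CI. The first two hunks below replace the single draw with a bounded retry. A minimal standalone sketch of the same pattern, using java.util.Random in place of the project's TestUtils helpers (illustrative only):

    java.util.Random rng = new java.util.Random(12345);
    int n = 4;                 // mask length (mb in the tests below)
    double p = 0.5;            // per-element probability of a 1
    double[] mask = new double[n];
    boolean allZero = true;
    int attempts = 0;
    while (attempts++ < 1000 && allZero) {
        allZero = true;
        for (int i = 0; i < n; i++) {
            mask[i] = rng.nextDouble() < p ? 1.0 : 0.0;   // Bernoulli(p) draw
            if (mask[i] > 0.0) allZero = false;           // found a non-zero entry
        }
    }
    assert !allZero : "no non-zero mask generated after " + attempts + " attempts";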
@@ -414,7 +414,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
         INDArray l = TestUtils.randomOneHot(mb, 3);
         INDArray lm = TestUtils.randomBernoulli(mb, 1);
 
-        assertTrue(lm.sumNumber().intValue() > 0);
+        int attempts = 0;
+        while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
+            lm = TestUtils.randomBernoulli(mb, 1);
+        }
+        assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
 
         boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
                 .labels(l).labelMask(lm));

@@ -467,7 +471,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
         INDArray l = TestUtils.randomOneHot(mb, 3);
         INDArray lm = TestUtils.randomBernoulli(mb, 1);
 
-        assertTrue(lm.sumNumber().intValue() > 0);
+        int attempts = 0;
+        while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
+            lm = TestUtils.randomBernoulli(mb, 1);
+        }
+        assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
 
         boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f})
                 .labels(new INDArray[]{l}).labelMask(new INDArray[]{lm}));

@@ -67,7 +67,9 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
             }
         }
 
-        SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
+        SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("SeqVecTests");
         sc = new JavaSparkContext(sparkConf);
     }
 
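The Spark hunks that follow all make the same change: pinning spark.driver.host to localhost so local-mode test contexts do not depend on the CI runner's hostname resolution. This is the "Spark config driver host config for CI" item from the commit message. A minimal sketch of the pattern in isolation (app name and master value are illustrative):

    SparkConf conf = new SparkConf()
            .setMaster("local[*]")                     // in-process Spark for tests
            .set("spark.driver.host", "localhost")     // avoid hostname-resolution failures on CI
            .setAppName("example");
    JavaSparkContext sc = new JavaSparkContext(conf);
    try {
        // ... test body ...
    } finally {
        sc.stop();                                     // release the context
    }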
@@ -61,7 +61,9 @@ public class SparkWord2VecTest extends BaseDL4JTest {
             sentences.add("one another sentence");
         }
 
-        SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
+        SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("SeqVecTests");
         sc = new JavaSparkContext(sparkConf);
     }
 

@@ -56,7 +56,9 @@ public class Word2VecTest {
     @Test
     public void testConcepts() throws Exception {
         // These are all default values for word2vec
-        SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[8]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");
 
         // Set SparkContext
         JavaSparkContext sc = new JavaSparkContext(sparkConf);

@@ -156,6 +158,7 @@ public class Word2VecTest {
     @Test
     public void testSparkW2VonBiggerCorpus() throws Exception {
         SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
                         .set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g")
                         .set("spark.executor.memory", "8g");
 

@@ -63,7 +63,7 @@ public class TextPipelineTest extends BaseSparkTest {
 
     @Before
     public void before() throws Exception {
-        conf = new SparkConf().setMaster("local[4]").setAppName("sparktest");
+        conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");
 
         // All the avaliable options. These are default values
         word2vec = new Word2Vec.Builder().minWordFrequency(1).setNGrams(1)
@@ -85,7 +85,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
         if (sc != null)
             return sc;
         // set to test mode
-        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+                .set("spark.driver.host", "localhost").setAppName("sparktest");
 
 
         sc = new JavaSparkContext(sparkConf);

@@ -59,7 +59,9 @@ public class BaseSparkKryoTest extends BaseSparkTest {
 
 
 
-        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+                .setAppName("sparktest")
+                .set("spark.driver.host", "localhost");
 
         sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
         sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");

@@ -89,7 +89,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
         if (sc != null)
             return sc;
         // set to test mode
-        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
+                .set("spark.driver.host", "localhost").setAppName("sparktest");
 
 
         sc = new JavaSparkContext(sparkConf);

@@ -72,8 +72,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageMLNIncorrectConfig() {
         //Should print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
-                        "org.apache.spark.serializer.KryoSerializer");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
+                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
 
         doTestMLN(sparkConf);
     }
@@ -83,6 +84,7 @@ public class TestKryoWarning {
     public void testKryoMessageMLNCorrectConfigKryo() {
         //Should NOT print warning message
         SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                        .set("spark.driver.host", "localhost")
                         .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                         .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
 

@@ -93,7 +95,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageMLNCorrectConfigNoKryo() {
         //Should NOT print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");
 
         doTestMLN(sparkConf);
     }

@@ -104,8 +108,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageCGIncorrectConfig() {
         //Should print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
-                        "org.apache.spark.serializer.KryoSerializer");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                .set("spark.driver.host", "localhost")
+                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
 
         doTestCG(sparkConf);
     }

@@ -115,6 +120,7 @@ public class TestKryoWarning {
     public void testKryoMessageCGCorrectConfigKryo() {
         //Should NOT print warning message
         SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
+                        .set("spark.driver.host", "localhost")
                         .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                         .set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
 

@@ -125,7 +131,9 @@ public class TestKryoWarning {
     @Ignore
     public void testKryoMessageCGCorrectConfigNoKryo() {
         //Should NOT print warning message
-        SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
+        SparkConf sparkConf = new SparkConf().setMaster("local[*]")
+                .set("spark.driver.host", "localhost")
+                .setAppName("sparktest");
 
         doTestCG(sparkConf);
     }
@@ -138,6 +138,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
         SparkConf sparkConf = new SparkConf();
         sparkConf.setMaster("local[" + nWorkers + "]");
         sparkConf.setAppName("Test");
+        sparkConf.set("spark.driver.host", "localhost");
 
         JavaSparkContext sc = new JavaSparkContext(sparkConf);
         return sc;

@@ -58,7 +58,7 @@ public class ExportSupportTest {
     }
 
     private void assertSupported(SparkConf conf) throws IOException {
-        JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test"));
+        JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test").set("spark.driver.host", "localhost"));
         try {
             assertTrue(ExportSupport.exportSupported(sc));
         } finally {

@@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
+import org.deeplearning4j.spark.BaseSparkTest;
 import org.deeplearning4j.spark.api.Repartition;
 import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
 import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
@@ -50,17 +51,13 @@ import static org.junit.Assert.*;
 /**
  * Created by Alex on 17/06/2016.
  */
-public class TestTrainingStatsCollection {
+public class TestTrainingStatsCollection extends BaseSparkTest {
 
     @Test
     public void testStatsCollection() throws Exception {
-        int nWorkers = 4;
+        int nWorkers = numExecutors();
 
-        SparkConf sparkConf = new SparkConf();
-        sparkConf.setMaster("local[" + nWorkers + "]");
-        sparkConf.setAppName("Test");
-
-        JavaSparkContext sc = new JavaSparkContext(sparkConf);
+        JavaSparkContext sc = getContext();
 
         try {
@@ -294,6 +294,11 @@ public class DifferentialFunctionFactory {
         return new ZerosLike(name, sameDiff(), input).outputVariable();
     }
 
+    public SDVariable zerosLike(String name, SDVariable input, DataType dataType) {
+        validateDifferentialFunctionsameDiff(input);
+        return new ZerosLike(name, sameDiff(), input, dataType).outputVariable();
+    }
+
     public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
         return create(name, shape, 'c', initialize, dataType);
     }
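The new factory method threads an explicit DataType through to ZerosLike. A hypothetical usage sketch through the public SameDiff API, which gains matching zerosLike overloads later in this diff (names are illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
    SDVariable sameType = sd.zerosLike(in);                        // zeros, FLOAT like the input
    SDVariable asInt = sd.zerosLike("zeros", in, DataType.INT);    // zeros with an explicit datatype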
@@ -1751,12 +1756,12 @@ public class DifferentialFunctionFactory {
         return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables();
     }
 
-    public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
-        return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable();
+    public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) {
+        return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable();
     }
 
-    public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
-        return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
+    public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) {
+        return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables();
     }
 
     public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){

@@ -2638,7 +2643,7 @@ public class DifferentialFunctionFactory {
         return new Polygamma(sameDiff, n,x).outputVariable();
     }
 
-    public SDVariable roll(SDVariable input, SDVariable shift) {
+    public SDVariable roll(SDVariable input, int shift) {
         return new Roll(sameDiff, input, shift).outputVariable();
     }
 

@@ -787,9 +787,10 @@ public abstract class SDBaseOps {
      * @param number Number of values to generate
      * @return SDVariable with linearly spaced elements
      */
-    public SDVariable linspace(DataType dataType, double start, double stop, long number) {
+    // TODO: fix or remove, currently it is internal recursion
+    /*public SDVariable linspace(DataType dataType, double start, double stop, long number) {
         return linspace(dataType, start, stop, number);
-    }
+    }*/
 
     /**
     * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
@@ -3093,6 +3094,9 @@ public abstract class SDBaseOps {
         return zerosLike(null, input);
     }
 
+    public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) {
+        return zerosLike(null, input, dataType);
+    }
     /**
      * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
      * if the input shape changes in later execution, the returned variable's shape will also be updated

@@ -3106,6 +3110,10 @@ public abstract class SDBaseOps {
         return updateVariableNameAndReference(ret, name);
     }
 
+    public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) {
+        SDVariable ret = f().zerosLike(name, input, dataType);
+        return updateVariableNameAndReference(ret, name);
+    }
     /**
      * See {@link #any(String, SDVariable, int...)}
 
@@ -2545,7 +2545,7 @@ public class SDMath extends SDOps {
      * @param shift number of places to shift elements
      * @return array
      */
-    public SDVariable roll(String name, SDVariable input, SDVariable shift) {
+    public SDVariable roll(String name, SDVariable input, int shift) {
         SDVariable res = f().roll(input,shift);
         return updateVariableNameAndReference(res, name);
     }
 
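Taking shift as a plain int matches how the op consumes it (it is added as an integer argument, see the Roll hunk further down). A hypothetical usage sketch (values illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.linspace(1, 6, 6));
    SDVariable rolled = sd.math().roll("rolled", in, 2);   // rotate elements by 2 positions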
@@ -815,8 +815,9 @@ public class FlatBuffersMapper {
         }
 
         int[] dims;
-        if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL
-                || node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) {
+        Type t = node.opType();
+        if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL
+                || t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Type.VARIANCE || t == Type.SUMMARYSTATS) {
             dims = node.getDimensions();
             if (dims == null)
                 dims = new int[0];

@@ -16,6 +16,7 @@
 
 package org.nd4j.autodiff.validation;
 
+import org.nd4j.linalg.api.ops.custom.*;
 import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
 import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
 import org.nd4j.linalg.api.ops.impl.reduce.HashCode;

@@ -38,10 +39,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOpDescriptor;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
-import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
-import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
-import org.nd4j.linalg.api.ops.custom.SpTreeCell;
 import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
 import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
 import org.nd4j.linalg.api.ops.impl.loss.bp.*;
@@ -1011,7 +1008,10 @@ public class OpValidation {
                 SpTreeCell.class,
                 CbowRound.class,
                 SkipGramRound.class,
-                HashCode.class
+                HashCode.class,
+                HashCode.class,
+                BitCast.class,
+                ToggleBits.class
         );
 
         return new HashSet<>(list);

@@ -200,7 +200,6 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul.class,
             org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp.class,
             org.nd4j.linalg.api.ops.impl.reduce.floating.AMean.class,
-            org.nd4j.linalg.api.ops.impl.reduce.floating.Bias.class,
             org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy.class,
             org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy.class,
             org.nd4j.linalg.api.ops.impl.reduce.floating.Mean.class,
@@ -16,21 +16,28 @@
 
 package org.nd4j.linalg.api.ops.custom;
 
+import lombok.Data;
+import lombok.NoArgsConstructor;
 import lombok.val;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
+import java.util.Arrays;
 import java.util.List;
 
 /**
  * This op takes arbitrary number of arrays as input, and returns single "flattened" vector
  *
  * @author raver119@gmail.com
  */
+@Data
+@NoArgsConstructor
 public class Flatten extends DynamicCustomOp {
-    private char order;
-
-    public Flatten() {
-        //
-    }
+    private int order;
 
     public Flatten(char order, INDArray... inputs) {
         this.order = order;

@@ -47,10 +54,21 @@ public class Flatten extends DynamicCustomOp {
         outputArguments.add(output);
     }
 
+    public Flatten(SameDiff sameDiff, char order, SDVariable... inputs) {
+        super(sameDiff, inputs);
+        this.order = order;
+        addIArgument(order);
+    }
+
     @Override
     public String opName() {
         return "flatten";
     }
 
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Arrays.asList(inputDataTypes.get(0));
+    }
 }
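With the new SameDiff constructor, Flatten can be placed in a graph rather than only executed eagerly on INDArrays. A hypothetical construction sketch (shapes and names are illustrative):

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.var("a", Nd4j.rand(2, 3));
    SDVariable b = sd.var("b", Nd4j.rand(1, 4));
    // 'c' order: row-major flattening of each input, concatenated into one vector
    SDVariable flat = new Flatten(sd, 'c', a, b).outputVariable();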
@@ -51,6 +51,14 @@ public class FusedBatchNorm extends DynamicCustomOp {
     public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
                           @NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
         super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
+        this.outputDataType = x.dataType();
+    }
+
+    public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
+                          int dataFormat, int isTraining) {
+        super("", sameDiff, new SDVariable[]{x, scale, offset});
+        addIArgument(dataFormat, isTraining);
+        this.outputDataType = x.dataType();
     }
 
     @Override
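The added constructor passes dataFormat and isTraining as integer arguments instead of requiring SDVariable inputs. A hypothetical sketch; the 0/0 encoding for data format and inference mode is an assumption, not something this diff confirms:

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.placeHolder("x", DataType.FLOAT, 1, 5, 5, 3);       // activations
    SDVariable scale = sd.var("scale", Nd4j.ones(DataType.FLOAT, 3));     // per-channel scale
    SDVariable offset = sd.var("offset", Nd4j.zeros(DataType.FLOAT, 3));  // per-channel offset
    // 0, 0: assumed flags for data format and inference mode (see lead-in)
    SDVariable[] outs = new FusedBatchNorm(sd, x, scale, offset, 0, 0).outputVariables();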
@@ -78,6 +86,8 @@ public class FusedBatchNorm extends DynamicCustomOp {
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT);   //Activations may be half, bfloat16, float32; mean/var is always float
+        return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
+                outputDataType == null ? DataType.FLOAT : outputDataType,
+                outputDataType == null ? DataType.FLOAT : outputDataType);
     }
 }

@@ -64,6 +64,6 @@ public class Lu extends DynamicCustomOp {
     public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Arrays.asList(inputDataTypes.get(0), indexDataType);
+        return Arrays.asList(inputDataTypes.get(0), indexDataType == null ? DataType.INT32 : indexDataType);
     }
 }

@@ -46,6 +46,11 @@ public class MatrixBandPart extends DynamicCustomOp {
         super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
     }
 
+    public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int minLower, int maxUpper) {
+        super("", sameDiff, new SDVariable[]{input});
+        addIArgument(minLower, maxUpper);
+    }
+
     @Override
     public String opName() {
         return "matrix_band_part";
@@ -45,6 +45,15 @@ public class Roll extends DynamicCustomOp {
         super("", sameDiff, new SDVariable[]{input,shift});
     }
 
+    public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
+        super("", sameDiff, new SDVariable[]{input,axes,shift});
+    }
+
+    public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
+        super("", sameDiff, new SDVariable[]{input});
+        addIArgument(shift);
+    }
+
     @Override
     public String opName() {
         return "roll";

@@ -7,9 +7,13 @@ import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
 @NoArgsConstructor
 public class TriangularSolve extends DynamicCustomOp {
@@ -24,11 +28,27 @@ public class TriangularSolve extends DynamicCustomOp {
         super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint});
     }
 
+    public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs,
+                           boolean lower, boolean adjoint) {
+        super(sameDiff, new SDVariable[] {matrix, rhs});
+        addBArgument(lower, adjoint);
+    }
+
     @Override
     public String opName() {
         return "triangular_solve";
     }
 
+    @Override
+    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
+        if(attributesForNode.containsKey("adjoint")){
+            addBArgument(attributesForNode.get("adjoint").getB());
+        }
+        if(attributesForNode.containsKey("lower")){
+            addBArgument(attributesForNode.get("lower").getB());
+        }
+    }
+
     @Override
     public String tensorflowName() {
         return "MatrixTriangularSolve";
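The boolean-argument constructor mirrors the TensorFlow import path in the same hunk, where the adjoint and lower attributes are mapped to boolean op arguments. A hypothetical sketch (sd, matrix, and rhs assumed to be a SameDiff instance and previously declared SDVariables):

    // solve L * x = rhs for a lower-triangular L, without transposing L
    TriangularSolve op = new TriangularSolve(sd, matrix, rhs, true, false);
    SDVariable x = op.outputVariable();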
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.broadcast;
 
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;

@@ -27,6 +28,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import java.util.Arrays;
 import java.util.List;
 
+@NoArgsConstructor
 public class BiasAddGrad extends DynamicCustomOp {
     protected boolean nchw = true;
 

@@ -40,7 +42,16 @@ public class BiasAddGrad extends DynamicCustomOp {
         super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
     }
 
-    public BiasAddGrad() {}
+    public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient,
+                       boolean nchw) {
+        addInputArgument(input, bias, gradient);
+        this.nchw = nchw;
+        addBArgument(nchw);
+    }
+
+    public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient) {
+        this(input, bias, gradient, false);
+    }
 
     @Override
     public int opNum() {
@@ -16,6 +16,8 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;

@@ -35,6 +37,8 @@ import java.util.List;
  *
  * @author raver119@gmail.com
  */
+@Data
+@NoArgsConstructor
 public class FirstIndex extends BaseIndexAccumulation {
     protected Condition condition;
     protected double compare;

@@ -50,9 +54,6 @@ public class FirstIndex extends BaseIndexAccumulation {
         this.extraArgs = new Object[] {compare, eps, (double) mode};
     }
 
-    public FirstIndex() {}
-
-
     public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
         this(x, condition, false, dimension);
     }

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Adam Gibson
  */
+@Data
 public class IAMax extends BaseIndexAccumulation {
     public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Adam Gibson
 */
+@Data
 public class IAMin extends BaseIndexAccumulation {
     public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -31,6 +32,7 @@ import java.util.List;
  *
  * @author Alex Black
 */
+@Data
 public class IMax extends BaseIndexAccumulation {
     public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;

@@ -30,6 +31,7 @@ import java.util.List;
  *
  * @author Alex Black
 */
+@Data
 public class IMin extends BaseIndexAccumulation {
     public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
         super(sameDiff, i_v, keepDims, dimensions);

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum;
 
+import lombok.Data;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;

@@ -36,6 +37,7 @@ import java.util.Map;
  *
  * @author raver119@gmail.com
 */
+@Data
 public class LastIndex extends BaseIndexAccumulation {
     protected Condition condition;
     protected double compare;

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;

@@ -29,6 +30,7 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 
+@Data
 public class ArgMax extends DynamicCustomOp {
 
     protected DataType outputType;

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
 
+import lombok.Data;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;

@@ -34,6 +35,7 @@ import java.util.Map;
  *
  * @author Alex Black
 */
+@Data
 public class ArgMin extends DynamicCustomOp {
 
     protected DataType outputType = DataType.LONG;
@@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
 
 import java.util.Arrays;
 import java.util.Collections;

@@ -38,8 +39,14 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
 
     protected int classesDim;
 
-    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
-        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+//    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
+//        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+//        this.classesDim = classesDim;
+//        addIArgument(classesDim);
+//    }
+
+    public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
         this.classesDim = classesDim;
         addIArgument(classesDim);
     }
@@ -66,7 +73,8 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         //No external gradient
         //Args: logits, weigths, label
-        SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim);
+        SDVariable[] args = args();
+        SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim);
         return Arrays.asList(grads);
     }
 }
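Dropping the weights argument leaves the Java wrapper with a two-input form (logits and labels only); the three-input constructor is retained above as a comment rather than deleted. A hypothetical call through the factory after this change, with labels and logits assumed to be same-shaped SDVariables and classes along dimension 1:

    SDVariable loss = f().lossSoftmaxCrossEntropyWithLogits(labels, logits, 1);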
@@ -20,14 +20,18 @@ import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.loss.LossReduce;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.Op;
 import org.nd4j.linalg.api.ops.impl.loss.BaseLoss;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 

@@ -56,4 +60,12 @@ public class SoftmaxCrossEntropyLossBp extends BaseLossBp {
     public String opName() {
         return "softmax_cross_entropy_loss_grad";
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
+                "Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
+
+        return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(1), inputDataTypes.get(2));    //Same as predictions
+    }
 }
@@ -19,8 +19,10 @@ package org.nd4j.linalg.api.ops.impl.loss.bp;
 import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
+import java.util.Arrays;
 import java.util.List;
 

@@ -34,8 +36,8 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
 
     protected int classesDim;
 
-    public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
-        super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
+    public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
+        super(null, sameDiff, new SDVariable[]{logits, labels}, false);
         this.classesDim = classesDim;
         addIArgument(classesDim);
     }

@@ -49,4 +51,9 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
     public List<SDVariable> doDiff(List<SDVariable> grad){
         throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        return Arrays.asList(arg(0).dataType(), arg(1).dataType());
+    }
 }
@@ -16,9 +16,12 @@
 
 package org.nd4j.linalg.api.ops.impl.reduce;
 
+import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.*;

@@ -30,12 +33,9 @@ import java.util.*;
  *
  * @author Alex Black
 */
+@NoArgsConstructor
 public class SufficientStatistics extends DynamicCustomOp {
 
-    public SufficientStatistics() {
-    }
-
-
     public SufficientStatistics(SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable axis, SDVariable shift) {
         super(null, sameDiff, argsNoNull(x, axis, shift), false);
     }
@@ -48,14 +48,30 @@ public class SufficientStatistics extends DynamicCustomOp {
         }
     }
 
+    public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes, INDArray shift) {
+        if (shift != null)
+            addInputArgument(x, axes, shift);
+        else
+            addInputArgument(x, axes);
+    }
+
+    public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes) {
+        this(x,axes,null);
+    }
+
     @Override
     public String opName() {
         return "sufficient_statistics";
     }
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> grad) {
         throw new UnsupportedOperationException("Backprop not yet implemented for op: " + getClass().getSimpleName());
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        // FIXME
+        return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0),inputDataTypes.get(0));
+    }
 }
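The new INDArray constructors allow eager execution, with the optional shift input simply omitted from the input list when null. A hypothetical eager sketch (axes given as a rank-1 array of dimension indices; the exact output semantics are inferred from the three output datatypes declared above):

    INDArray x = Nd4j.rand(DataType.FLOAT, 3, 4);
    INDArray axes = Nd4j.createFromArray(0);
    INDArray[] stats = Nd4j.exec(new SufficientStatistics(x, axes));   // three outputs, per the datatypes above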
@@ -111,7 +111,7 @@ public class TensorMmul extends DynamicCustomOp {
         int[][] deletedAxes = new int[][]{
                 removeIndex(aAxes, sumAxes[0]),
                 removeIndex(bAxes, sumAxes[1])};
-        int[] gAxes = range(0, i_v1.get(0).getShape().length);
+        int[] gAxes = range(0, i_v1.get(0).eval().shape().length);
         int[][] firstAxes = new int[][]{
                 Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length),
                 deletedAxes[1]

@@ -144,18 +144,20 @@ public class TensorMmul extends DynamicCustomOp {
                                     int[][] axes) {
 
         int validationLength = Math.min(axes[0].length, axes[1].length);
+        INDArray aArray = a.eval();
+        INDArray bArray = b.eval();
         for (int i = 0; i < validationLength; i++) {
-            if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]])
+            if (aArray.shape()[axes[0][i]] != bArray.shape()[axes[1][i]])
                 throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
             if (axes[0][i] < 0)
-                axes[0][i] += a.getShape().length;
+                axes[0][i] += aArray.shape().length;
             if (axes[1][i] < 0)
-                axes[1][i] += b.getShape().length;
+                axes[1][i] += bArray.shape().length;
 
         }
 
         List<Integer> listA = new ArrayList<>();
-        for (int i = 0; i < a.getShape().length; i++) {
+        for (int i = 0; i < aArray.shape().length; i++) {
             if (!Ints.contains(axes[0], i))
                 listA.add(i);
         }

@@ -164,7 +166,7 @@ public class TensorMmul extends DynamicCustomOp {
 
 
         List<Integer> listB = new ArrayList<>();
-        for (int i = 0; i < b.getShape().length; i++) {
+        for (int i = 0; i < bArray.shape().length; i++) {
             if (!Ints.contains(axes[1], i))
                 listB.add(i);
         }

@@ -172,9 +174,9 @@ public class TensorMmul extends DynamicCustomOp {
         int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB));
 
         int n2 = 1;
-        int aLength = Math.min(a.getShape().length, axes[0].length);
+        int aLength = Math.min(aArray.shape().length, axes[0].length);
         for (int i = 0; i < aLength; i++) {
-            n2 *= a.getShape()[axes[0][i]];
+            n2 *= aArray.shape()[axes[0][i]];
         }
 
         //if listA and listB are empty these do not initialize.

@@ -186,13 +188,13 @@ public class TensorMmul extends DynamicCustomOp {
         } else {
             oldShapeA = Longs.toArray(listA);
             for (int i = 0; i < oldShapeA.length; i++)
-                oldShapeA[i] = a.getShape()[(int) oldShapeA[i]];
+                oldShapeA[i] = aArray.shape()[(int) oldShapeA[i]];
         }
 
         int n3 = 1;
-        int bNax = Math.min(b.getShape().length, axes[1].length);
+        int bNax = Math.min(bArray.shape().length, axes[1].length);
         for (int i = 0; i < bNax; i++) {
-            n3 *= b.getShape()[axes[1][i]];
+            n3 *= bArray.shape()[axes[1][i]];
         }
 
@@ -203,7 +205,7 @@ public class TensorMmul extends DynamicCustomOp {
         } else {
             oldShapeB = Longs.toArray(listB);
             for (int i = 0; i < oldShapeB.length; i++)
-                oldShapeB[i] = b.getShape()[(int) oldShapeB[i]];
+                oldShapeB[i] = bArray.shape()[(int) oldShapeB[i]];
         }
 
 

@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.reduce.bp;
 
+import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;

@@ -30,7 +31,7 @@ import java.util.List;
 /**
  * @author Alex Black
 */
-
+@NoArgsConstructor
 public abstract class BaseReductionBp extends DynamicCustomOp {
 
     protected boolean keepDims;

@@ -96,7 +97,12 @@ public abstract class BaseReductionBp extends DynamicCustomOp {
         addArgs();
     }
 
-    public BaseReductionBp(){}
+    public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions){
+        super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2});
+        this.keepDims = keepDims;
+        this.dimensions = dimensions;
+        addArgs();
+    }
 
     protected void addArgs(){
         addTArgument(keepDims ? 1 : 0);
@@ -16,17 +16,23 @@
 
 package org.nd4j.linalg.api.ops.impl.reduce.bp;
 
+import lombok.NoArgsConstructor;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 
+import java.util.Arrays;
 import java.util.List;
 
 
 /**
  * Backprop op for Dot pairwise reduction operation
  *
  * @author Alex Black
 */
+@NoArgsConstructor
 public class DotBp extends BaseReductionBp {
 
     public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) {

@@ -37,10 +43,22 @@ public class DotBp extends BaseReductionBp {
         super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions);
     }
 
-    public DotBp(){}
+    public DotBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput,
+                 INDArray outputX, INDArray outputY, boolean keepDims, int... dimensions) {
+        super(origInput1, origInput2, gradAtOutput, outputX, outputY, keepDims, dimensions);
+    }
 
     @Override
     public String opName() {
         return "reduce_dot_bp";
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
+        Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatype for %s, got input %s", getClass(), dataTypes);
+        Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0));
+        Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(1));
+        Preconditions.checkState(dataTypes.get(2).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(2));
+        return Arrays.asList(dataTypes.get(0), dataTypes.get(0));
+    }
 }
@@ -1,86 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.reduce.floating;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
-import org.nd4j.linalg.api.ops.BaseReduceOp;
-
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-
-/**
- * Calculate a bias
- *
- * @author Adam Gibson
- */
-public class Bias extends BaseReduceFloatOp {
-
-    private double mean;
-
-    public Bias(SameDiff sameDiff, SDVariable i_v, int[] dimensions, double mean) {
-        super(sameDiff, i_v, dimensions);
-        this.mean = mean;
-    }
-
-    public Bias(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions, double mean) {
-        super(sameDiff, i_v, i_v2, dimensions);
-        this.mean = mean;
-    }
-
-    public Bias() {}
-
-    public Bias(INDArray x, int... dimensions) {
-        super(x, dimensions);
-    }
-
-    @Override
-    public Map<String, Object> propertiesForFunction() {
-        Map<String,Object> ret = new LinkedHashMap<>();
-        ret.put("mean",mean);
-        return ret;
-    }
-
-    @Override
-    public int opNum() {
-        return 2;
-    }
-
-    @Override
-    public String opName() {
-        return "bias";
-    }
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return null;
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " +  opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " +  opName());
-    }
-}
@ -24,6 +24,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@ -45,6 +46,7 @@ public class SequenceMask extends DynamicCustomOp {
    public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) {
        super(null, sameDiff, new SDVariable[] {input, maxLen}, false);
        this.dataType = dataType;
        addDArgument(dataType);
    }

    public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) {
@ -53,11 +55,21 @@ public class SequenceMask extends DynamicCustomOp {
        this.is_static_maxlen = true;
        addIArgument(maxLen);
        this.dataType = dataType;
        addDArgument(dataType);
    }

    public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) {
        super(null, sameDiff, new SDVariable[] {input}, false);
        this.dataType = dataType;
        addDArgument(dataType);
    }

    public SequenceMask(INDArray input, int maxLen, DataType dataType) {
        addInputArgument(input);
        addIArgument(maxLen);
        //addIArgument(dataType.toInt());
        addDArgument(dataType);
        this.dataType = dataType;
    }



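The constructors above now record the requested output type both in the `dataType` field and as a d-argument, so the native kernel produces the mask in that type directly. A minimal sketch of the eager form added here (array values are illustrative, not from the commit):

    INDArray lengths = Nd4j.createFromArray(1, 2, 3);
    SequenceMask op = new SequenceMask(lengths, 4, DataType.FLOAT);   // maxLen = 4
    INDArray mask = Nd4j.exec(op)[0];
    // rows: [1,0,0,0], [1,1,0,0], [1,1,1,0], produced as FLOAT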
@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.shape;

import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
@ -39,23 +40,37 @@ import java.util.Map;
 * @author Adam Gibson
 */
@Slf4j
@NoArgsConstructor
public class ZerosLike extends DynamicCustomOp {

    protected DataType outputType;    //Allow customizing dtype for TF import

    public ZerosLike() {
    public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
        this(name, sameDiff, input, false, input.dataType());
    }

    public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
        this(name, sameDiff, input, false);
    public ZerosLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
        this(name, sameDiff, input, false, dataType);
    }

    public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) {
        this(name, sameDiff, input, inPlace, input.dataType());
    }

    public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace, DataType dataType) {
        super(name, sameDiff, new SDVariable[]{input}, inPlace);
        addDArgument(dataType);
    }

    public ZerosLike(INDArray in, INDArray out){
        this(in, out, in.dataType());
    }

    public ZerosLike(INDArray in, INDArray out, DataType dataType) {
        super(null, in, out, null, null);
        if (dataType != null) {
            addDArgument(dataType);
        }
    }



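The constructor chain above defaults the output type to the input's data type unless one is supplied, and forwards it as a d-argument. A minimal sketch of the dtype-override path, with hypothetical arrays:

    INDArray in = Nd4j.rand(DataType.FLOAT, 2, 3);
    INDArray out = Nd4j.create(DataType.DOUBLE, 2, 3);
    Nd4j.exec(new ZerosLike(in, out, DataType.DOUBLE));   // out: zeros, with the overridden DOUBLE type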
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.Collections;
@ -52,16 +53,13 @@ public class BatchToSpace extends DynamicCustomOp {
    }

    public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
        super(null, sameDiff, args, inPlace);
        super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace);

        this.blocks = blocks;
        this.crops = crops;

        for (val b : blocks)
            addIArgument(b);

        for (int e = 0; e < crops.length; e++)
            addIArgument(crops[e][0], crops[e][1]);
    }

    @Override

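The constructor change above wires the `crops` matrix in as a second, constant array input rather than relying on callers to pass it inside `args`; the next hunk applies the same pattern to `SpaceToBatch` with its `padding` matrix. A minimal sketch with assumed shapes:

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 4, 1, 1, 1));
    SDVariable out = new BatchToSpace(sd, new SDVariable[]{x}, new int[]{2, 2},
            new int[][]{{0, 0}, {0, 0}}, false).outputVariable();
    // out has shape [1, 2, 2, 1]; the crops constant is created inside the constructor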
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.Collections;
@ -53,16 +54,12 @@ public class SpaceToBatch extends DynamicCustomOp {
    }

    public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
        super(null, sameDiff, args, inPlace);
        super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace);

        this.blocks = blocks;
        this.padding = padding;

        for (val b : blocks)
            addIArgument(b);

        for (int e = 0; e < padding.length; e++)
            addIArgument(padding[e][0], padding[e][1]);
        addIArgument(blocks[0]);
    }

    @Override

@ -58,7 +58,8 @@ public class UnsortedSegmentMax extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }


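The relaxed precondition accepts both calling conventions: the two-input form (`data`, `segmentIds`, with `numSegments` as an integer argument) and the three-input form produced by TensorFlow import, where `numSegments` arrives as a third array. The same relaxation is repeated for the other unsorted segment ops below. A sketch of the two dtype lists that now both pass:

    // data, segmentIds, numSegments-as-array (TF import form)
    List<DataType> threeInputs = Arrays.asList(DataType.FLOAT, DataType.INT, DataType.INT);
    // data, segmentIds only (numSegments passed as an i-arg)
    List<DataType> twoInputs = Arrays.asList(DataType.FLOAT, DataType.INT);
    // calculateOutputDataTypes returns [FLOAT] for either list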
@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.segment;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class UnsortedSegmentMean extends DynamicCustomOp {

    private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
        addIArgument(numSegments);
    }

    public UnsortedSegmentMean(){ }

    @Override
    public String opName(){
        return "unsorted_segment_mean";
@ -60,7 +60,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }


@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.segment;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class UnsortedSegmentMin extends DynamicCustomOp {

    private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
        addIArgument(numSegments);
    }

    public UnsortedSegmentMin(){ }

    @Override
    public String opName(){
        return "unsorted_segment_min";
@ -60,7 +60,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}

@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.segment;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -31,6 +32,7 @@ import java.util.List;
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class UnsortedSegmentProd extends DynamicCustomOp {

    private int numSegments;
@ -41,8 +43,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
        addIArgument(numSegments);
    }

    public UnsortedSegmentProd(){ }

    @Override
    public String opName(){
        return "unsorted_segment_prod";
@ -60,7 +60,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}

@ -16,10 +16,12 @@

package org.nd4j.linalg.api.ops.impl.transforms.segment;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

import java.util.ArrayList;
@ -31,18 +33,23 @@ import java.util.List;
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class UnsortedSegmentSqrtN extends DynamicCustomOp {

    private int numSegments;

    public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
        addInputArgument(data, segmentIds);
        addIArgument(numSegments);
        this.numSegments = numSegments;
    }

    public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
        super(null, sameDiff,  new SDVariable[] {data, segmentIds}, false);
        this.numSegments = numSegments;
        addIArgument(numSegments);
    }

    public UnsortedSegmentSqrtN(){ }

    @Override
    public String opName(){
        return "unsorted_segment_sqrt_n";
@ -60,7 +67,8 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        List<DataType> out = new ArrayList<>();
        for( int i=0; i<numSegments; i++ ){
            out.add(inputDataTypes.get(0));

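The new `INDArray` constructor above lets the op execute eagerly, outside a SameDiff graph. A minimal sketch with hypothetical values (each segment sum is divided by the square root of its element count):

    INDArray data = Nd4j.createFromArray(1.0, 3.0, 2.0, 6.0);
    INDArray ids = Nd4j.createFromArray(0, 0, 1, 1);
    INDArray out = Nd4j.exec(new UnsortedSegmentSqrtN(data, ids, 2))[0];
    // out ~ [4/sqrt(2), 8/sqrt(2)] = [2.8284, 5.6569]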
@ -16,6 +16,7 @@

package org.nd4j.linalg.api.ops.impl.transforms.segment;

import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -32,6 +33,7 @@ import java.util.List;
 *
 * @author Alex Black
 */
@NoArgsConstructor
public class UnsortedSegmentSum extends DynamicCustomOp {

    private int numSegments;
@ -42,8 +44,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
        addIArgument(numSegments);
    }

    public UnsortedSegmentSum(){ }

    @Override
    public String opName(){
        return "unsorted_segment_sum";
@ -61,7 +61,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
        Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
                "Expected 2 or 3 input data types for %s, got %s", getClass(), inputDataTypes);
        //TODO Allow customizing output type
        return Collections.singletonList(Nd4j.defaultFloatingPointType());
    }

@ -50,6 +50,11 @@ public class LayerOpValidation extends BaseOpValidation {
        super(backend);
    }

    @Override
    public long getTimeoutMilliseconds() {
        return 90000L;
    }

    @Test
    public void testXwPlusB() {
        Nd4j.getRandom().setSeed(12345);
@ -319,7 +324,7 @@

    @Test
    public void testIm2Col() {
        OpValidationSuite.ignoreFailing();      //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
        //OpValidationSuite.ignoreFailing();      //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
        Nd4j.getRandom().setSeed(12345);

        int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};

@ -32,6 +32,9 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
@ -513,7 +516,7 @@ public class MiscOpValidation extends BaseOpValidation {
    @Test
    public void testTrace(){
        //TODO need to work out how to handle shape_op for scalars...
        OpValidationSuite.ignoreFailing();
        //OpValidationSuite.ignoreFailing();
        Nd4j.getRandom().setSeed(12345);
        for( int[] inShape : new int[][]{{3,3}}){

@ -546,12 +549,15 @@
        SDVariable x = sameDiff.var("x", arr);
        SDVariable y = sameDiff.var("y", arr2);
        SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}});
        assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.getShape());
        assertEquals(32, sameDiff.numElements());
        assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
                result.eval().shape());
        assertEquals(16, sameDiff.numElements());

        SDVariable loss = sameDiff.standardDeviation(result, true);
        sameDiff.addLossVariable(loss);

        String err = OpValidation.validate(new TestCase(sameDiff));
        assertNull(err);
    }

    @Test
@ -1782,4 +1788,338 @@ public class MiscOpValidation extends BaseOpValidation {
        assertEquals(exp, out);         //Values in x not in y
        assertEquals(exp, outIdx);      //Indices of the values in x not in y
    }

    @Test
    public void testDivideNoNan() {
        OpValidationSuite.ignoreFailing();  //TODO: implement DivideNoNan.doDiff()

        SameDiff sameDiff = SameDiff.create();

        INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        SDVariable input1 = sameDiff.var(in1);
        SDVariable input2 = sameDiff.var(in2);

        INDArray expected = Nd4j.ones(3,4);

        SDVariable output = new DivideNoNan(sameDiff, input1, input2).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testDigamma() {

        INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        INDArray expected = Nd4j.createFromArray(new double[]{
                -0.5772157,0.42278433,0.9227843,1.2561177,1.5061177,1.7061176,1.8727844,2.0156415,2.1406415,2.2517526,2.3517525,2.4426618
        }).reshape(3,4);

        val tc = new OpTestCase(new Digamma(in1)).expectedOutput(0, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

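For reference, the expected values in testDigamma are the digamma function evaluated at 1..12 (textbook definition, not from the commit):

    \psi(x) = \frac{d}{dx}\ln\Gamma(x), \qquad \psi(1) = -\gamma \approx -0.5772157, \qquad \psi(x+1) = \psi(x) + \frac{1}{x}

so the first two expected entries are -gamma and 1 - gamma ≈ 0.42278433.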
    @Test
    public void testFlatten() {

        SameDiff sameDiff = SameDiff.create();

        INDArray x = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3,3,3);
        SDVariable sdx = sameDiff.var(x);

        INDArray expected = Nd4j.linspace(DataType.DOUBLE,1,27,1);

        SDVariable output = new Flatten(sameDiff, 'c', sdx).outputVariable();
        SDVariable loss = sameDiff.standardDeviation(sdx, true);
        sameDiff.addLossVariable(loss);

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testFusedBatchNorm() {
        OpValidationSuite.ignoreFailing();
        SameDiff sameDiff = SameDiff.create();

        INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
        INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
        scale.assign(0.5);
        INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
        offset.assign(2.0);

        SDVariable input1 = sameDiff.var(x);
        SDVariable input2 = sameDiff.var(scale);
        SDVariable input3 = sameDiff.var(offset);

        INDArray expectedY = Nd4j.createFromArray(new double[]{
                985.5258,  985.5258,  985.5258,  985.5258,
                659.7321,  659.7321,  659.7321,  659.7321,
                399.0972,  399.0972,  399.0972,  399.0972,
                203.6210,  203.6210,  203.6210,  203.6210,
                73.3036,   73.3036,   73.3036,   73.3036,
                8.1448,    8.1448,    8.1448,    8.1448,
                8.1448,    8.1448,    8.1448,    8.1448,
                73.3036,   73.3036,   73.3036,   73.3036,
                203.6210,  203.6210,  203.6210,  203.6210,
                399.0972,  399.0972,  399.0972,  399.0972,
                659.7321,  659.7321,  659.7321,  659.7321,
                985.5258,  985.5258,  985.5258,  985.5258}).reshape(x.shape());
        INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23.,  24.,  25.,  26.});
        INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526,  208.00001526,  208.00001526,  208.00001526});

        SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables();
        SDVariable loss = sameDiff.standardDeviation(input1, true);
        sameDiff.addLossVariable(loss);

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(outputs[0].name(), expectedY)
                .expectedOutput(outputs[1].name(), expectedBatchMean)
                .expectedOutput(outputs[2].name(), expectedBatchVar);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testIgamma() {

        INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        INDArray expected = Nd4j.createFromArray(new double[]{
                0.63212055,0.59399414,0.5768099,0.56652874,0.5595013,0.5542634,0.5501591,0.5463888,0.54329145,0.54048204,0.5378594,0.53233755
        }).reshape(3,4);

        val tc = new OpTestCase(new Igamma(in1, in2)).expectedOutput(0, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testIgammaC() {

        INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);


        INDArray expected = Nd4j.createFromArray(new double[]{
                0.36787945,0.40600586,0.42319012,0.43347126,0.4404987,0.44573656,0.4498409,0.45361117,0.45670855,0.459518,0.46214062,0.46766248
        }).reshape(3,4);

        val tc = new OpTestCase(new Igammac(in1, in2)).expectedOutput(0, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

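For reference, `Igamma` and `Igammac` are the regularized lower and upper incomplete gamma functions (textbook definitions, not from the commit):

    P(a,x) = \frac{1}{\Gamma(a)}\int_0^x t^{a-1}e^{-t}\,dt, \qquad Q(a,x) = 1 - P(a,x)

Since both tests use identical `in1`/`in2`, the two expected arrays sum to 1 elementwise, e.g. 0.63212055 + 0.36787945 = 1, with P(1,1) = 1 - e^{-1} ≈ 0.6321.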
    @Test
    public void testLgamma() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4);
        SDVariable sdInput = sameDiff.var(in);

        INDArray expected = Nd4j.createFromArray(new double[]{
                0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308
        }).reshape(3,4);

        SDVariable output = new Lgamma(sameDiff, sdInput).outputVariable();

        SDVariable loss = sameDiff.standardDeviation(sdInput, true);
        sameDiff.addLossVariable(loss);

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testLu() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in1 = Nd4j.createFromArray(new double[]{
                1., 2., 3., 0., 2., 3., 0., 0., 7.
        }).reshape(3,3);

        SDVariable input1 = sameDiff.var(in1);

        INDArray expected = Nd4j.createFromArray(new double[]{
                1., 2., 3., 0., 2., 3., 0., 0., 7
        }).reshape(3,3);

        INDArray pexpected = Nd4j.createFromArray(new int[]{
                0, 1, 2
        });

        sameDiff.loss.l2Loss(input1);
        SDVariable[] output = new Lu(sameDiff, input1).outputVariables();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output[0].name(), expected)
                .expectedOutput(output[1].name(), pexpected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testMatrixBandPart() {
        OpValidationSuite.ignoreFailing();
        SameDiff sameDiff = SameDiff.create();

        INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
                0.7271f,0.1804f,0.5056f,0.8925f,
                0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);

        SDVariable sdInput = sameDiff.var(input);
        SDVariable sdInput1 = sameDiff.constant(1);
        SDVariable sdInput2 = sameDiff.constant(-1);

        INDArray expected = Nd4j.createFromArray(new float[]{
                0.7788f,    0.8012f,    0.7244f,    0.2309f,
                0.7271f,    0.1804f,    0.5056f,    0.8925f,
                0.f,    0.9234f,    0.0856f,    0.7938f
        }).reshape(3,4);

        sameDiff.loss.l2Loss(sdInput);
        SDVariable output = new MatrixBandPart(sameDiff, sdInput, 1, -1).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testPolygamma() {

        INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);

        INDArray expected = Nd4j.createFromArray(new double[]{
                1.644934,-0.4041138,0.1189394,-0.03750069,0.01226151,-0.0041002957,0.001392272,-4.780109E-4,1.6549716E-4,-5.7675967E-5,2.0206635E-5,-7.1101636E-6
        }).reshape(3,4);

        val tc = new OpTestCase(new Polygamma(in1, in2)).expectedOutput(0, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testTriangularSolve() {

        INDArray a = Nd4j.createFromArray(new float[]{
                3.f,  0.f,  0.f,  0.f,
                2.f,  1.f,  0.f,  0.f,
                1.f,  0.f,  1.f,  0.f,
                1.f,  1.f,  1.f,  1.f
        }).reshape(4,4);

        INDArray b = Nd4j.createFromArray(new float[]{
                4.f, 2.f, 4.f, 2.f
        }).reshape(4,1);

        INDArray expected = Nd4j.createFromArray(new float[]{
                1.333333f,  2.0f, 4.0f, 2.0f
        }).reshape(4,1);

        val tc = new OpTestCase(new TriangularSolve(a, b, false, true)).expectedOutput(0, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

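A worked check of the testTriangularSolve expectation, assuming the constructor flags map to TensorFlow's matrix_triangular_solve attributes (lower, adjoint): with lower = false only the upper triangle of `a` is referenced, which for this matrix is just the diagonal diag(3, 1, 1, 1); the adjoint of a diagonal matrix is itself, so the system reduces to elementwise division,

    x_i = b_i / a_{ii} = [4/3,\ 2,\ 4,\ 2]^T \approx [1.333333,\ 2.0,\ 4.0,\ 2.0]^T

which matches `expected`.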
    @Test
    public void testBiasAdd() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in1 = Nd4j.linspace(1, 12, 12);
        INDArray in2 = Nd4j.linspace(1, 12, 12);

        SDVariable input1 = sameDiff.var(in1);
        SDVariable input2 = sameDiff.var(in2);

        INDArray expected = Nd4j.createFromArray(new double[]{
                2.0000,    4.0000,    6.0000,    8.0000,   10.0000,   12.0000,   14.0000,   16.0000,   18.0000,   20.0000,   22.0000,   24.0000
        });

        SDVariable output = new BiasAdd(sameDiff, input1, input2, false).outputVariable();
        SDVariable loss = sameDiff.standardDeviation(input1, true);
        sameDiff.addLossVariable(loss);
        SDVariable loss2 = sameDiff.standardDeviation(input2, true);
        sameDiff.addLossVariable(loss2);

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testBiasAddGrad() {

        SameDiff sameDiff = SameDiff.create();

        INDArray x = Nd4j.linspace(DataType.FLOAT,1, 24, 24).reshape(2,2,2,3);
        INDArray grad = Nd4j.linspace(DataType.FLOAT, 0.1, 0.1, 24).reshape(2,2,2,3);

        INDArray bias = Nd4j.createFromArray(new float[]{-1.f, -2.f, -3.f});

        INDArray expected = Nd4j.createFromArray(new float[]{9.2f, 10.f , 10.8f});

        OpTestCase tc = new OpTestCase(new BiasAddGrad(x, bias, grad,false)).
                expectedOutput(0, grad).
                expectedOutput(1, expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testRoll() {

        INDArray x = Nd4j.createFromArray(new double[]{    11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,     12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
                21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42,     22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
                reshape(2,2,4,2);

        INDArray expected = Nd4j.createFromArray(new double[]{    22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
                12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
                21.41, 21.42, 22.11, 22.12
        }).reshape(x.shape());

        int shift = 6;

        val tc = new OpTestCase(new Roll(x,shift)).expectedOutput(0,expected);
        String err = OpValidation.validate(tc);

        assertNull(err);
    }
}

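A worked check of the testRoll expectation: with no axis argument, `Roll` rotates the flattened buffer of n = 2*2*4*2 = 32 elements, so

    \text{output}[(i + 6) \bmod 32] = \text{input}[i]

and the last six values of the flattened input (22.21 ... 22.42) wrap around to the front, which is exactly the `expected` array.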
@ -35,16 +35,20 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
@ -96,7 +100,7 @@ public class ReductionOpValidation extends BaseOpValidation {
    @Test
    public void testZeroCount() {
        List<String> allFailed = new ArrayList<>();
        for (int i = 0; i < 2; i++) {
        for (int i = 0; i < 21; i++) {
            SameDiff sd = SameDiff.create();

            INDArray ia;
@ -159,25 +163,25 @@

    @Test
    public void testReductionGradientsSimple() {
        OpValidationSuite.ignoreFailing();  //TODO TEMPORARY DUE TO CRASHES
        //OpValidationSuite.ignoreFailing();  //TODO TEMPORARY DUE TO CRASHES
        //Test reductions: final and only function
        Nd4j.getRandom().setSeed(12345);

        List<String> failed = new ArrayList<>();

        for (int i = 0; i < 21; i++) {

            SameDiff sd = SameDiff.create();

            int nOut = 4;
            int minibatch = 10;
            SDVariable input = sd.var("in", -1, nOut);
            SDVariable input = sd.var("in", minibatch, nOut);
            INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
            long length = nOut * minibatch;

            SDVariable loss;
            String name;
            TestCase tc = new TestCase(sd);
            boolean gradCheck = true;
            switch (i) {
                case 0:
                    loss = sd.mean("loss", input);
@ -234,11 +238,13 @@
                    loss = sd.math().countNonZero("loss", input);
                    name = "countNonZero";
                    tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
                    gradCheck = false;  //Long out, not floating point
                    break;
                case 11:
                    loss = sd.math().countZero("loss", input);
                    name = "countZero";
                    tc.expectedOutput("loss", Nd4j.scalar(0));
                    tc.expectedOutput("loss", Nd4j.scalar(0L));
                    gradCheck = false;  //Long out, not floating point
                    break;
                case 12:
                    loss = sd.math().amax("loss", input);
@ -272,7 +278,7 @@
                    loss = sd.math().logSumExp("loss", input);
                    INDArray expArr = Transforms.exp(inputArr);
                    double sum = expArr.sumNumber().doubleValue();
                    tc.expected("loss", Nd4j.create(new double[]{Math.log(sum)}));
                    tc.expected("loss", Nd4j.scalar(Math.log(sum)));
                    break;
                case 18:
                    inputArr = Nd4j.rand(minibatch, nOut);
@ -307,9 +313,15 @@
            log.info("*** Starting test: " + msg);

            sd.associateArrayWithVariable(inputArr, input);

            if(gradCheck) {
                sd.addLossVariable(loss);
            }

            tc.testName(msg);
            if(!gradCheck){
                tc.gradientCheck(false);
            }

            String error = OpValidation.validate(tc, true);
            if (error != null)
                failed.add(error);
@ -629,14 +641,14 @@

        List<String> failed = new ArrayList<>();
        for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) {
            for (int i = 6; i < 7; i++) {
            for (int i = 0; i < 7; i++) {

                SameDiff sd = SameDiff.create();
                sd.setLogExecution(false);


                SDVariable in = sd.var("in", -1, d1, d2);
                SDVariable in2 = sd.var("in2", -1, d1, d2);
                SDVariable in = sd.var("in", d0, d1, d2);
                SDVariable in2 = sd.var("in2", d0, d1, d2);

                INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100);
                INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);
@ -645,40 +657,43 @@
                SDVariable reduced;
                String name;
                TestCase tc = new TestCase(sd);
                Double maxRelError = null;
                switch (i) {
                    case 0:
                        reduced = sd.math().manhattanDistance(in, in2, reduceDims);
                        name = "manhattan";
                        exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, false, false, reduceDims));
                        break;
                    case 1:
                        reduced = sd.math().euclideanDistance(in, in2, reduceDims);
                        name = "euclidean";
                        exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, false, false, reduceDims));
                        break;
                    case 2:
                        inArr.muli(1e-4);
                        in2Arr.muli(1e-4);
                        reduced = sd.math().cosineSimilarity(in, in2, reduceDims);
                        name = "cosine";
                        exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, false, false, reduceDims));
                        maxRelError = 1e-4;
                        break;
                    case 3:
                        reduced = sd.math().cosineDistance(in, in2, reduceDims);
                        name = "cosinedistance";
                        exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, false, false, reduceDims));
                        maxRelError = 1e-4;
                        break;
                    case 4:
                        reduced = sd.math().hammingDistance(in, in2, reduceDims);
                        name = "hamming";
                        exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, false, false, reduceDims));
                        break;
                    case 5:
                        name = "jaccard";
                        reduced = sd.math().jaccardDistance(name, in, in2, reduceDims);
                        inArr.divi(100).addi(0.1);
                        in2Arr.divi(100).addi(0.1);
                        exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, true, false, reduceDims));
                        exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, false, false, reduceDims));

                        if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2)
                            continue;
@ -708,6 +723,9 @@

                tc.expected(reduced, exp);

                if(maxRelError != null)
                    tc.gradCheckMaxRelativeError(maxRelError);

                String error = OpValidation.validate(tc, true);
                if (error != null) {
                    failed.add(msg + " - " + error);
@ -768,7 +786,6 @@
    }

    @Test
    @Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
    public void testNormalizeMomentsOp() {
        INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
        INDArray ssSum = data.sum(0);
@ -780,7 +797,7 @@
        INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape());
        INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape());

        OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.INT, 10), ssSum, ssSqSum, mean, var));
        OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.DOUBLE, 10), ssSum, ssSqSum, mean, var));
        op.expectedOutput(0, meanExp);
        op.expectedOutput(1, varExp);

@ -821,7 +838,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        List<String> failed = new ArrayList<>();
        List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]);

        INDArray in = Nd4j.rand(3, 4);
        INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);

        for (int t = 0; t < 4; t++) {
            int[] d = dims.get(t);
@ -838,52 +855,47 @@ public class ReductionOpValidation extends BaseOpValidation {
                switch (i) {
                    case 0:
                        reduce = s.argmax(dim);
                        exp = Nd4j.argMax(in, dim).castTo(DataType.DOUBLE);
                        exp = Nd4j.argMax(in, dim);
                        name = "argmax";
                        break;
                    case 1:
                        reduce = s.argmin(dim);
                        exp = Nd4j.argMin(in, dim).castTo(DataType.DOUBLE);
                        exp = Nd4j.argMin(in, dim);
                        name = "argmin";
                        break;
                    case 2:
                        reduce = sd.math().iamax(s, dim);
                        exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
                        exp = exp.castTo(DataType.DOUBLE);
                        name = "iamax";
                        break;
                    case 3:
                        reduce = sd.math().iamin(s, dim);
                        exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
                        exp = exp.castTo(DataType.DOUBLE);
                        name = "iamin";
                        break;
                    case 4:
                        reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim);
                        exp = in.sum(dim).assign(0);
                        exp = exp.castTo(DataType.DOUBLE);
                        exp = in.sum(dim).assign(0).castTo(DataType.INT64);
                        name = "firstindex";
                        break;
                    case 5:
                        reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim);
                        if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2});
                        else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3});
                        else exp = Nd4j.scalar(11.0);
                        exp = exp.castTo(DataType.DOUBLE);
                        if (t == 0) exp = Nd4j.createFromArray(2L, 2, 2, 2);
                        else if (t == 1) exp = Nd4j.createFromArray(3L, 3, 3);
                        else exp = Nd4j.scalar(11L);
                        name = "lastindex";
                        break;
                    case 6:
                        reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim);
                        if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3});
                        else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4});
                        else exp = Nd4j.scalar(12.0);
                        exp = exp.castTo(DataType.DOUBLE);
                        if (t == 0) exp = Nd4j.createFromArray(3L, 3, 3, 3);
                        else if (t == 1) exp = Nd4j.createFromArray(4L, 4, 4);
                        else exp = Nd4j.scalar(12L);
                        name = "matchConditionCount";
                        break;
                    default:
                        throw new RuntimeException();
                }

                SDVariable preCast = reduce;
                reduce = reduce.castTo(DataType.DOUBLE);

                SDVariable loss;
@ -894,7 +906,7 @@ public class ReductionOpValidation extends BaseOpValidation {
                }

                TestCase tc = new TestCase(sd)
                        .expected(reduce, exp)
                        .expected(preCast, exp)
                        .gradientCheck(false)
                        .testName(name + " - " + (dim == null ? null : Arrays.toString(dim)));
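The pattern in this hunk is consistent: index-style reductions produce integer indices (or counts), so the expectations are now built as INT64 values and compared against the pre-cast variable rather than blindly casting everything to double. A minimal sketch of the assumption being encoded (that ND4J's index reductions emit a LONG-typed array):

    INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4);
    INDArray idx = Nd4j.argMax(in, 0);
    // argmax returns indices, hence an integer (LONG) array is assumed here,
    // which is why exp is built e.g. as Nd4j.createFromArray(2L, 2, 2, 2)
    assert idx.dataType() == DataType.LONG;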
@ -1335,4 +1347,254 @@ public class ReductionOpValidation extends BaseOpValidation {
            }
        }
    }
    @Test
    public void testSufficientStatisticsOp() {
        INDArray data = Nd4j.createFromArray(new double[]{
                5.5, 0.0, 0.3, 5.5,
                1.5, 0.0, 1.3, 6.5,
                8.6, 0.0, 0.0, 0.4,
                2.5, 1.0, 0.3, 4.5,
                1.5, 1.0, 1.3, 1.5,
                3.5, 0.0, 1.3, 2.5,
                2.6, 2.0, 3.0, 1.4,
                4.5, 1.0, 0.3, 0.5
        }).reshape(2, 2, 2, 4);
        INDArray axes = Nd4j.linspace(DataType.LONG, 0, 3, 1);

        OpTestCase op = new OpTestCase(new SufficientStatistics(data, axes));

        INDArray expected1 = Nd4j.scalar(8.0);
        INDArray expected2 = Nd4j.createFromArray(new double[]{
                30.2, 5.0, 7.8, 22.8
        });
        INDArray expected3 = Nd4j.createFromArray(new double[]{
                154.22, 7.0, 14.34, 103.62
        });

        op.expectedOutput(0, expected1);
        op.expectedOutput(1, expected2);
        op.expectedOutput(2, expected3);

        String err = OpValidation.validate(op);
        assertNull(err);
    }
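The expected outputs line up with plain ND4J reductions, assuming SufficientStatistics returns {count, sum, sum of squares} over the given axes. A quick cross-check:

    long count     = data.length() / data.size(3);   // 2*2*2 = 8 values reduced per output element
    INDArray sum   = data.sum(0, 1, 2);              // -> [30.2, 5.0, 7.8, 22.8]
    INDArray sumSq = data.mul(data).sum(0, 1, 2);    // -> [154.22, 7.0, 14.34, 103.62]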
    @Test
    public void testStandardDeviation() {

        for (boolean keepDims : new boolean[]{false, true}) {
            SameDiff sameDiff = SameDiff.create();

            INDArray in = Nd4j.linspace(1, 8, 8).reshape(2, 4);
            SDVariable input = sameDiff.var(in);
            INDArray expected = Nd4j.createFromArray(new double[]{
                    2, 2, 2, 2
            });

            if(keepDims){
                expected = expected.reshape(1, 4);
            }

            SDVariable output = new StandardDeviation(sameDiff, input, false, keepDims, new int[]{0}).outputVariable();

            TestCase tc = new TestCase(sameDiff)
                    .gradientCheck(true)
                    .expectedOutput(output.name(), expected);

            String err = OpValidation.validate(tc);
            assertNull(err);
        }
    }
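The expected value of 2 per column checks out for the population (non-bias-corrected) standard deviation, which the `false` constructor argument appears to select. Each column of linspace(1..8).reshape(2, 4) holds x and x + 4, so:

    \sigma = \sqrt{\tfrac{1}{2}\left((x - \bar{x})^2 + (x + 4 - \bar{x})^2\right)}
           = \sqrt{\tfrac{1}{2}(4 + 4)} = 2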
    @Test
    public void testSquaredNorm() {

        for (boolean keepDims : new boolean[]{false, true}) {
            SameDiff sameDiff = SameDiff.create();

            INDArray in = Nd4j.linspace(1, 4, 4);
            SDVariable input = sameDiff.var(in);
            INDArray expected = Nd4j.scalar(30.0);
            if(keepDims)
                expected = expected.reshape(1);

            SDVariable output = new SquaredNorm(sameDiff, input, keepDims, new int[]{0}).outputVariable();

            TestCase tc = new TestCase(sameDiff)
                    .gradientCheck(true)
                    .expectedOutput(output.name(), expected);

            String err = OpValidation.validate(tc);
            assertNull(err);
        }
    }
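The scalar expectation is just the sum of squares of the input vector:

    \|x\|_2^2 = 1^2 + 2^2 + 3^2 + 4^2 = 30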
    @Test
    public void testShannonEntropy() {
        OpValidationSuite.ignoreFailing();  //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 4, 4).castTo(DataType.DOUBLE);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.scalar(-69.68162);

        SDVariable output = new ShannonEntropy(sameDiff, input, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testEntropy() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 4, 4);
        SDVariable input = sameDiff.var(in);
        double expected = -10.2273;

        SDVariable output = new Entropy(sameDiff, input, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), Nd4j.scalar(expected));

        String err = OpValidation.validate(tc);
        assertNull(err);
    }
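The entropy expectation can be verified by hand, assuming Entropy computes the unnormalized -Σ x·ln x over the reduction dimension:

    -\sum_i x_i \ln x_i = -(1\ln 1 + 2\ln 2 + 3\ln 3 + 4\ln 4) \approx -10.2273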
    @Test
    public void testAMean() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.createFromArray(new double[]{
                5.0, 6.0, 7.0, 8.0
        });

        SDVariable output = new AMean(sameDiff, input, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testMean() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.createFromArray(new double[]{
                5.0, 6.0, 7.0, 8.0
        });

        SDVariable output = new Mean(sameDiff, input, false, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }
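AMean (mean of absolute values) and Mean share the same expectation here because every input is positive. Column j of linspace(1..12).reshape(3, 4) is {j+1, j+5, j+9}, so:

    \bar{x}_j = \frac{(j+1) + (j+5) + (j+9)}{3} = j + 5 \quad\Rightarrow\quad \{5, 6, 7, 8\}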
    @Test
    public void testNorm1() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.createFromArray(new double[]{
                15.0, 18.0, 21.0, 24.0
        });

        SDVariable output = new Norm1(sameDiff, input, false, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testNorm2() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.createFromArray(new double[]{
                10.3441, 11.8322, 13.3791, 14.9666
        });

        SDVariable output = new Norm2(sameDiff, input, false, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }
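Both norm expectations follow directly from the columns; for the first column {1, 5, 9}:

    \|x\|_1 = 1 + 5 + 9 = 15, \qquad
    \|x\|_2 = \sqrt{1 + 25 + 81} = \sqrt{107} \approx 10.3441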
    @Test
    public void testNormMax() {

        SameDiff sameDiff = SameDiff.create();

        INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        SDVariable input = sameDiff.var(in);
        INDArray expected = Nd4j.createFromArray(new double[]{
                9.0, 10.0, 11.0, 12.0
        });

        SDVariable output = new NormMax(sameDiff, input, false, new int[]{0}).outputVariable();

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }

    @Test
    public void testSoftmaxCrossEntropyWithLogitsLoss() {
        OpValidationSuite.ignoreFailing();

        SameDiff sameDiff = SameDiff.create();

        INDArray labels = Nd4j.createFromArray(new double[]{
                0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0
        }).reshape(2, 3, 4);

        INDArray logits = Nd4j.linspace(DataType.DOUBLE, 0.1, 0.1, 24).reshape(2, 3, 4);
        INDArray expected = Nd4j.createFromArray(new double[]{
                0.26328, 1.46328, 1.72656, 0.0, 0.26328, 0.0, 1.46328, 0.26328, 1.72656, 0.0, 1.72656, 1.46328
        }).reshape(3, 4);

        SDVariable sdLogits = sameDiff.var("logits", logits);
        SDVariable sdLabels = sameDiff.var("labels", labels);
        SDVariable loss = sameDiff.math().abs(sdLogits);


        SDVariable output = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff, sdLogits, sdLabels, 0).outputVariable();
        sameDiff.setLossVariables(output);

        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
                .expectedOutput(output.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);
    }
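The expected losses can be reproduced by hand, assuming the trailing constructor argument (0) selects the class dimension. Along dimension 0 the logits at each (row, column) position form the pair (a, a + 1.2), because linspace with step 0.1 over a 2x3x4 array separates the two class slices by 12 x 0.1 = 1.2. Two-class softmax cross-entropy with binary labels then reduces to:

    \text{labels }(0,1):\ \ln(1 + e^{-1.2}) \approx 0.26328, \qquad
    \text{labels }(1,0):\ \ln(1 + e^{1.2})  \approx 1.46328,
    \text{labels }(1,1):\ 0.26328 + 1.46328 = 1.72656,  \qquad
    \text{labels }(0,0):\ 0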
}

@ -1016,7 +1016,7 @@ public class ShapeOpValidation extends BaseOpValidation {

    @Test
    public void testConstant(){
        OpValidationSuite.ignoreFailing();
        //OpValidationSuite.ignoreFailing();

        //Case 0: no shape
        SameDiff sd = SameDiff.create();
@ -1035,7 +1035,9 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
        loss = constant.std(true);

        assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
        assertNull(OpValidation.validate(new TestCase(sd)
                .gradientCheck(false)
                .expected(constant, Nd4j.create(DataType.FLOAT, 3, 4, 5))));
    }


@ -1272,7 +1274,7 @@ public class ShapeOpValidation extends BaseOpValidation {

            SameDiff sd = SameDiff.create();
            SDVariable data = sd.var("data", d);
            SDVariable segments = sd.var("segments", s);
            SDVariable segments = sd.constant("segments", s);

            SDVariable sm;
            INDArray exp;
@ -1326,6 +1328,7 @@ public class ShapeOpValidation extends BaseOpValidation {
            }

            SDVariable loss = sm.std(true);
            sd.addLossVariable(loss);

            TestCase tc = new TestCase(sd)
                    .testName(op)
@ -1363,17 +1366,19 @@ public class ShapeOpValidation extends BaseOpValidation {

    @Test
    public void testSequenceMask() {
        OpValidationSuite.ignoreFailing();  //2018-01-09: output datatype issue?
        SameDiff sameDiff = SameDiff.create();
        INDArray arr = Nd4j.create(new float[] {1, 3, 2}).reshape(3);
        SDVariable lengths = sameDiff.var("lengths", arr);
        INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
        // arr is not trainable, so it's a constant in the model
        SDVariable lengths = sameDiff.constant(arr);

        // Test with static max len
        int maxlen = 5;
        INDArray expected = Nd4j.create(new float[] {1, 0, 0, 0, 0,
                        1, 1, 1, 0, 0,
                        1, 1, 0, 0, 0},
                new long[]{3, 5});
        INDArray expected = Nd4j.create(new float[] {
                    1.f, 0.f, 0.f, 0.f, 0.f,
                    1.f, 1.f, 1.f, 0.f, 0.f,
                    1.f, 1.f, 0.f, 0.f, 0.f
                }).reshape(3, 5);
        INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
        SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
        assertArrayEquals(expected.shape(), result1.eval().shape());
        assertEquals(expected, result1.eval());
@ -1382,14 +1387,14 @@ public class ShapeOpValidation extends BaseOpValidation {

        String err = OpValidation.validate(new TestCase(sameDiff)
                .expected(result1, expected)
                .gradCheckSkipVariables(lengths.name()));
                .gradientCheck(false));
        assertNull(err);

        // Test with dynamic maxlen
        lengths = sameDiff.var("lengths2", arr); // required because of an internal samediff bug
        SDVariable maxLen = sameDiff.var("maxLen", Nd4j.create(new float[]{5}).reshape(1));
        lengths = sameDiff.constant("lengths2", arr);
        SDVariable maxLen = sameDiff.constant("maxLen", Nd4j.scalar(5));
        SDVariable result2 = sameDiff.sequenceMask(lengths, maxLen, DataType.FLOAT);
        assertArrayEquals(expected.shape(), result2.eval().shape());
//        assertArrayEquals(expected.shape(), result2.eval().shape());
        assertEquals(expected, result2.eval());
    }
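For context, a sketch of the mask semantics the expectations above encode (ones up to, but excluding, each length):

    // mask[i][j] = (j < lengths[i]) ? 1 : 0
    // lengths = {1, 3, 2}, maxLen = 5  ->  rows with 1, 3 and 2 leading ones
    INDArray lengths = Nd4j.createFromArray(1, 3, 2);
    INDArray[] mask = Nd4j.exec(new SequenceMask(lengths, 5, DataType.FLOAT));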
@ -303,7 +303,7 @@ public class TransformOpValidation extends BaseOpValidation {

    @Test
    public void testBatchToSpace() {
        OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
        //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
        Nd4j.getRandom().setSeed(1337);

        int miniBatch = 4;
@ -314,7 +314,6 @@ public class TransformOpValidation extends BaseOpValidation {
        int[] cropShape = new int[]{M, 2};

        INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
        INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
        INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT);

        SameDiff sd = SameDiff.create();
@ -323,7 +322,8 @@ public class TransformOpValidation extends BaseOpValidation {

        INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1);
        DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space")
                .addInputs(input, blocks, crops)
                .addInputs(input, crops)
                .addIntegerArguments(2)
                .addOutputs(expOut).build();
        Nd4j.getExecutioner().exec(op);

@ -340,7 +340,7 @@ public class TransformOpValidation extends BaseOpValidation {

    @Test
    public void testSpaceToBatch() {
        OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
        //OpValidationSuite.ignoreFailing();          //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863

        Nd4j.getRandom().setSeed(7331);

@ -352,7 +352,6 @@ public class TransformOpValidation extends BaseOpValidation {
        int[] paddingShape = new int[]{M, 2};

        INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
        INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
        INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT);

        SameDiff sd = SameDiff.create();
@ -361,7 +360,8 @@ public class TransformOpValidation extends BaseOpValidation {

        INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1);
        DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch")
                .addInputs(input, blocks, padding)
                .addIntegerArguments(2)
                .addInputs(input, padding)
                .addOutputs(expOut).build();
        Nd4j.getExecutioner().exec(op);
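Both hunks above make the same API adjustment: the block-size array is dropped from the op inputs and supplied as an integer argument instead, so batch_to_space/space_to_batch now take just the data plus the crops/padding array. A minimal sketch of the resulting calling convention, mirroring the new test code:

    DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space")
            .addInputs(input, crops)      // block sizes no longer an input array
            .addIntegerArguments(2)       // block size supplied as an int argument
            .addOutputs(expOut).build();
    Nd4j.getExecutioner().exec(op);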
@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
@ -1737,4 +1738,19 @@ public class CustomOpsTests extends BaseNd4jTest {

        assertEquals(expected, ret[0]);
    }

    @Test
    public void testSequenceMask() {
        INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
        // Test with static max len
        int maxlen = 2;
        INDArray expected = Nd4j.createFromArray(new int[]{
                1, 0, 0,
                1, 1, 1,
                1, 1, 0
        }).reshape(3, 3);

        INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));
        assertEquals(expected, ret[0]);
    }
}

@ -318,8 +318,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>

  "num2Scalar" should "convert number to Scalar INDArray" in {

    assert(1.toScalar.data() == List(1).toNDArray.data())
    assert(2f.toScalar.data() == List(2).toNDArray.data())
    assert(3d.toScalar.data() == List(3).toNDArray.data())
    assert(1.toScalar.reshape(1) == List(1).toNDArray)
    assert(2f.toScalar.reshape(1) == List(2f).toNDArray)
    assert(3d.toScalar.reshape(1) == List(3d).toNDArray)
  }
}