Improving SameDiff tests coverage (#227)
* Gradients tests added * Fix for Standard deviation serialization + test Signed-off-by: Alex Black <blacka101@gmail.com> * More fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Test fixed * Spark config driver host config for CI Signed-off-by: Alex Black <blacka101@gmail.com> * Op validation timeout increase Signed-off-by: Alex Black <blacka101@gmail.com> * Gradient check - fix for low probability test failure due to randomly all 0s mask Signed-off-by: AlexDBlack <blacka101@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
5c9e0bc2bb
commit
8c0e378ec3
|
@ -414,7 +414,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
|||
INDArray l = TestUtils.randomOneHot(mb, 3);
|
||||
INDArray lm = TestUtils.randomBernoulli(mb, 1);
|
||||
|
||||
assertTrue(lm.sumNumber().intValue() > 0);
|
||||
int attempts = 0;
|
||||
while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
|
||||
lm = TestUtils.randomBernoulli(mb, 1);
|
||||
}
|
||||
assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(net).input(f)
|
||||
.labels(l).labelMask(lm));
|
||||
|
@ -467,7 +471,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
|
|||
INDArray l = TestUtils.randomOneHot(mb, 3);
|
||||
INDArray lm = TestUtils.randomBernoulli(mb, 1);
|
||||
|
||||
assertTrue(lm.sumNumber().intValue() > 0);
|
||||
int attempts = 0;
|
||||
while(attempts++ < 1000 && lm.sumNumber().intValue() == 0){
|
||||
lm = TestUtils.randomBernoulli(mb, 1);
|
||||
}
|
||||
assertTrue("Could not generate non-zero mask after " + attempts + " attempts", lm.sumNumber().intValue() > 0);
|
||||
|
||||
boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{f})
|
||||
.labels(new INDArray[]{l}).labelMask(new INDArray[]{lm}));
|
||||
|
|
|
@ -67,7 +67,9 @@ public class SparkSequenceVectorsTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.setAppName("SeqVecTests");
|
||||
sc = new JavaSparkContext(sparkConf);
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,9 @@ public class SparkWord2VecTest extends BaseDL4JTest {
|
|||
sentences.add("one another sentence");
|
||||
}
|
||||
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("SeqVecTests");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.setAppName("SeqVecTests");
|
||||
sc = new JavaSparkContext(sparkConf);
|
||||
}
|
||||
|
||||
|
|
|
@ -56,7 +56,9 @@ public class Word2VecTest {
|
|||
@Test
|
||||
public void testConcepts() throws Exception {
|
||||
// These are all default values for word2vec
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.setAppName("sparktest");
|
||||
|
||||
// Set SparkContext
|
||||
JavaSparkContext sc = new JavaSparkContext(sparkConf);
|
||||
|
@ -156,6 +158,7 @@ public class Word2VecTest {
|
|||
@Test
|
||||
public void testSparkW2VonBiggerCorpus() throws Exception {
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[8]").setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.set("spark.driver.maxResultSize", "4g").set("spark.driver.memory", "8g")
|
||||
.set("spark.executor.memory", "8g");
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ public class TextPipelineTest extends BaseSparkTest {
|
|||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest");
|
||||
conf = new SparkConf().setMaster("local[4]").setAppName("sparktest").set("spark.driver.host", "localhost");
|
||||
|
||||
// All the avaliable options. These are default values
|
||||
word2vec = new Word2Vec.Builder().minWordFrequency(1).setNGrams(1)
|
||||
|
|
|
@ -85,7 +85,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
|
|||
if (sc != null)
|
||||
return sc;
|
||||
// set to test mode
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
|
||||
.set("spark.driver.host", "localhost").setAppName("sparktest");
|
||||
|
||||
|
||||
sc = new JavaSparkContext(sparkConf);
|
||||
|
|
|
@ -59,7 +59,9 @@ public class BaseSparkKryoTest extends BaseSparkTest {
|
|||
|
||||
|
||||
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
|
||||
.setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost");
|
||||
|
||||
sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||
sparkConf.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
|
||||
|
|
|
@ -89,7 +89,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable
|
|||
if (sc != null)
|
||||
return sc;
|
||||
// set to test mode
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]").set("spark.driver.host", "localhost").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[" + numExecutors() + "]")
|
||||
.set("spark.driver.host", "localhost").setAppName("sparktest");
|
||||
|
||||
|
||||
sc = new JavaSparkContext(sparkConf);
|
||||
|
|
|
@ -72,8 +72,9 @@ public class TestKryoWarning {
|
|||
@Ignore
|
||||
public void testKryoMessageMLNIncorrectConfig() {
|
||||
//Should print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
|
||||
"org.apache.spark.serializer.KryoSerializer");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||
|
||||
doTestMLN(sparkConf);
|
||||
}
|
||||
|
@ -83,6 +84,7 @@ public class TestKryoWarning {
|
|||
public void testKryoMessageMLNCorrectConfigKryo() {
|
||||
//Should NOT print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
|
||||
|
||||
|
@ -93,7 +95,9 @@ public class TestKryoWarning {
|
|||
@Ignore
|
||||
public void testKryoMessageMLNCorrectConfigNoKryo() {
|
||||
//Should NOT print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.setAppName("sparktest");
|
||||
|
||||
doTestMLN(sparkConf);
|
||||
}
|
||||
|
@ -104,8 +108,9 @@ public class TestKryoWarning {
|
|||
@Ignore
|
||||
public void testKryoMessageCGIncorrectConfig() {
|
||||
//Should print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest").set("spark.serializer",
|
||||
"org.apache.spark.serializer.KryoSerializer");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||
|
||||
doTestCG(sparkConf);
|
||||
}
|
||||
|
@ -115,6 +120,7 @@ public class TestKryoWarning {
|
|||
public void testKryoMessageCGCorrectConfigKryo() {
|
||||
//Should NOT print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
|
||||
.set("spark.kryo.registrator", "org.nd4j.Nd4jRegistrator");
|
||||
|
||||
|
@ -125,7 +131,9 @@ public class TestKryoWarning {
|
|||
@Ignore
|
||||
public void testKryoMessageCGCorrectConfigNoKryo() {
|
||||
//Should NOT print warning message
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparktest");
|
||||
SparkConf sparkConf = new SparkConf().setMaster("local[*]")
|
||||
.set("spark.driver.host", "localhost")
|
||||
.setAppName("sparktest");
|
||||
|
||||
doTestCG(sparkConf);
|
||||
}
|
||||
|
|
|
@ -138,6 +138,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine {
|
|||
SparkConf sparkConf = new SparkConf();
|
||||
sparkConf.setMaster("local[" + nWorkers + "]");
|
||||
sparkConf.setAppName("Test");
|
||||
sparkConf.set("spark.driver.host", "localhost");
|
||||
|
||||
JavaSparkContext sc = new JavaSparkContext(sparkConf);
|
||||
return sc;
|
||||
|
|
|
@ -58,7 +58,7 @@ public class ExportSupportTest {
|
|||
}
|
||||
|
||||
private void assertSupported(SparkConf conf) throws IOException {
|
||||
JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test"));
|
||||
JavaSparkContext sc = new JavaSparkContext(conf.setAppName("Test").set("spark.driver.host", "localhost"));
|
||||
try {
|
||||
assertTrue(ExportSupport.exportSupported(sc));
|
||||
} finally {
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.spark.BaseSparkTest;
|
||||
import org.deeplearning4j.spark.api.Repartition;
|
||||
import org.deeplearning4j.spark.api.stats.CommonSparkTrainingStats;
|
||||
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
|
||||
|
@ -50,17 +51,13 @@ import static org.junit.Assert.*;
|
|||
/**
|
||||
* Created by Alex on 17/06/2016.
|
||||
*/
|
||||
public class TestTrainingStatsCollection {
|
||||
public class TestTrainingStatsCollection extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
public void testStatsCollection() throws Exception {
|
||||
int nWorkers = 4;
|
||||
int nWorkers = numExecutors();
|
||||
|
||||
SparkConf sparkConf = new SparkConf();
|
||||
sparkConf.setMaster("local[" + nWorkers + "]");
|
||||
sparkConf.setAppName("Test");
|
||||
|
||||
JavaSparkContext sc = new JavaSparkContext(sparkConf);
|
||||
JavaSparkContext sc = getContext();
|
||||
|
||||
try {
|
||||
|
||||
|
|
|
@ -294,6 +294,11 @@ public class DifferentialFunctionFactory {
|
|||
return new ZerosLike(name, sameDiff(), input).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable zerosLike(String name, SDVariable input, DataType dataType) {
|
||||
validateDifferentialFunctionsameDiff(input);
|
||||
return new ZerosLike(name, sameDiff(), input, dataType).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
|
||||
return create(name, shape, 'c', initialize, dataType);
|
||||
}
|
||||
|
@ -1751,12 +1756,12 @@ public class DifferentialFunctionFactory {
|
|||
return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
|
||||
return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, weights, labels, classDim).outputVariable();
|
||||
public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) {
|
||||
return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, SDVariable weights, int classDim) {
|
||||
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, weights, labels, classDim).outputVariables();
|
||||
public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) {
|
||||
return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables();
|
||||
}
|
||||
|
||||
public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){
|
||||
|
@ -2638,7 +2643,7 @@ public class DifferentialFunctionFactory {
|
|||
return new Polygamma(sameDiff, n,x).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable roll(SDVariable input, SDVariable shift) {
|
||||
public SDVariable roll(SDVariable input, int shift) {
|
||||
return new Roll(sameDiff, input, shift).outputVariable();
|
||||
}
|
||||
|
||||
|
|
|
@ -787,9 +787,10 @@ public abstract class SDBaseOps {
|
|||
* @param number Number of values to generate
|
||||
* @return SDVariable with linearly spaced elements
|
||||
*/
|
||||
public SDVariable linspace(DataType dataType, double start, double stop, long number) {
|
||||
// TODO: fix or remove, currently it is internal recursion
|
||||
/*public SDVariable linspace(DataType dataType, double start, double stop, long number) {
|
||||
return linspace(dataType, start, stop, number);
|
||||
}
|
||||
}*/
|
||||
|
||||
/**
|
||||
* Create a new 1d array with values evenly spaced between values 'start' and 'stop'
|
||||
|
@ -3093,6 +3094,9 @@ public abstract class SDBaseOps {
|
|||
return zerosLike(null, input);
|
||||
}
|
||||
|
||||
public SDVariable zerosLike(@NonNull SDVariable input, @NonNull DataType dataType) {
|
||||
return zerosLike(null, input, dataType);
|
||||
}
|
||||
/**
|
||||
* Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
|
||||
* if the input shape changes in later execution, the returned variable's shape will also be updated
|
||||
|
@ -3106,6 +3110,10 @@ public abstract class SDBaseOps {
|
|||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
public SDVariable zerosLike(String name, @NonNull SDVariable input, @NonNull DataType dataType) {
|
||||
SDVariable ret = f().zerosLike(name, input, dataType);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #any(String, SDVariable, int...)}
|
||||
|
|
|
@ -2545,7 +2545,7 @@ public class SDMath extends SDOps {
|
|||
* @param shift number of places to shift elements
|
||||
* @return array
|
||||
*/
|
||||
public SDVariable roll(String name, SDVariable input, SDVariable shift) {
|
||||
public SDVariable roll(String name, SDVariable input, int shift) {
|
||||
SDVariable res = f().roll(input,shift);
|
||||
return updateVariableNameAndReference(res, name);
|
||||
}
|
||||
|
|
|
@ -815,8 +815,9 @@ public class FlatBuffersMapper {
|
|||
}
|
||||
|
||||
int[] dims;
|
||||
if (node.opType() == Op.Type.REDUCE_FLOAT || node.opType() == Op.Type.REDUCE_SAME || node.opType() == Op.Type.REDUCE_BOOL
|
||||
|| node.opType() == Op.Type.REDUCE_LONG || node.opType() == Op.Type.INDEXREDUCE || node.opType() == Op.Type.REDUCE3) {
|
||||
Type t = node.opType();
|
||||
if (t == Op.Type.REDUCE_FLOAT || t == Op.Type.REDUCE_SAME || t == Op.Type.REDUCE_BOOL
|
||||
|| t == Op.Type.REDUCE_LONG || t == Op.Type.INDEXREDUCE || t == Op.Type.REDUCE3 || t == Type.VARIANCE || t == Type.SUMMARYSTATS) {
|
||||
dims = node.getDimensions();
|
||||
if (dims == null)
|
||||
dims = new int[0];
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.validation;
|
||||
|
||||
import org.nd4j.linalg.api.ops.custom.*;
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
|
||||
|
@ -38,10 +39,6 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
|
||||
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
|
||||
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
|
||||
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.*;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.bp.*;
|
||||
|
@ -1011,7 +1008,10 @@ public class OpValidation {
|
|||
SpTreeCell.class,
|
||||
CbowRound.class,
|
||||
SkipGramRound.class,
|
||||
HashCode.class
|
||||
HashCode.class,
|
||||
HashCode.class,
|
||||
BitCast.class,
|
||||
ToggleBits.class
|
||||
);
|
||||
|
||||
return new HashSet<>(list);
|
||||
|
|
|
@ -200,7 +200,6 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.floating.AMean.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.floating.Bias.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.floating.Mean.class,
|
||||
|
|
|
@ -16,21 +16,28 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.val;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This op takes arbitrary number of arrays as input, and returns single "flattened" vector
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class Flatten extends DynamicCustomOp {
|
||||
private char order;
|
||||
|
||||
public Flatten() {
|
||||
//
|
||||
}
|
||||
private int order;
|
||||
|
||||
public Flatten(char order, INDArray... inputs) {
|
||||
this.order = order;
|
||||
|
@ -47,10 +54,21 @@ public class Flatten extends DynamicCustomOp {
|
|||
outputArguments.add(output);
|
||||
}
|
||||
|
||||
public Flatten(SameDiff sameDiff, char order, SDVariable... inputs) {
|
||||
super(sameDiff, inputs);
|
||||
this.order = order;
|
||||
addIArgument(order);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "flatten";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Arrays.asList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,6 +51,14 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
|||
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
|
||||
@NonNull SDVariable dataFormat, @NonNull SDVariable isTraining) {
|
||||
super("", sameDiff, new SDVariable[]{x, scale, offset, dataFormat, isTraining});
|
||||
this.outputDataType = x.dataType();
|
||||
}
|
||||
|
||||
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
|
||||
int dataFormat, int isTraining) {
|
||||
super("", sameDiff, new SDVariable[]{x, scale, offset});
|
||||
addIArgument(dataFormat, isTraining);
|
||||
this.outputDataType = x.dataType();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -78,6 +86,8 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
|||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float
|
||||
return Arrays.asList(outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||
outputDataType == null ? DataType.FLOAT : outputDataType,
|
||||
outputDataType == null ? DataType.FLOAT : outputDataType);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,6 +64,6 @@ public class Lu extends DynamicCustomOp {
|
|||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Arrays.asList(inputDataTypes.get(0), indexDataType);
|
||||
return Arrays.asList(inputDataTypes.get(0), indexDataType == null ? DataType.INT32 : indexDataType);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,11 @@ public class MatrixBandPart extends DynamicCustomOp {
|
|||
super("", sameDiff, new SDVariable[]{input, minLower, maxUpper});
|
||||
}
|
||||
|
||||
public MatrixBandPart(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int minLower, int maxUpper) {
|
||||
super("", sameDiff, new SDVariable[]{input});
|
||||
addIArgument(minLower, maxUpper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "matrix_band_part";
|
||||
|
|
|
@ -45,6 +45,15 @@ public class Roll extends DynamicCustomOp {
|
|||
super("", sameDiff, new SDVariable[]{input,shift});
|
||||
}
|
||||
|
||||
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable axes, @NonNull SDVariable shift) {
|
||||
super("", sameDiff, new SDVariable[]{input,axes,shift});
|
||||
}
|
||||
|
||||
public Roll(@NonNull SameDiff sameDiff, @NonNull SDVariable input, int shift) {
|
||||
super("", sameDiff, new SDVariable[]{input});
|
||||
addIArgument(shift);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "roll";
|
||||
|
|
|
@ -7,9 +7,13 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class TriangularSolve extends DynamicCustomOp {
|
||||
|
@ -24,11 +28,27 @@ public class TriangularSolve extends DynamicCustomOp {
|
|||
super(sameDiff, new SDVariable[] {matrix, rhs, lower, adjoint});
|
||||
}
|
||||
|
||||
public TriangularSolve(SameDiff sameDiff, SDVariable matrix, SDVariable rhs,
|
||||
boolean lower, boolean adjoint) {
|
||||
super(sameDiff, new SDVariable[] {matrix, rhs});
|
||||
addBArgument(lower, adjoint);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "triangular_solve";
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
if(attributesForNode.containsKey("adjoint")){
|
||||
addBArgument(attributesForNode.get("adjoint").getB());
|
||||
}
|
||||
if(attributesForNode.containsKey("lower")){
|
||||
addBArgument(attributesForNode.get("lower").getB());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "MatrixTriangularSolve";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.broadcast;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -27,6 +28,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
|||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class BiasAddGrad extends DynamicCustomOp {
|
||||
protected boolean nchw = true;
|
||||
|
||||
|
@ -40,7 +42,16 @@ public class BiasAddGrad extends DynamicCustomOp {
|
|||
super(new INDArray[]{input, bias, gradient}, wrapOrNull(output));
|
||||
}
|
||||
|
||||
public BiasAddGrad() {}
|
||||
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient,
|
||||
boolean nchw) {
|
||||
addInputArgument(input, bias, gradient);
|
||||
this.nchw = nchw;
|
||||
addBArgument(nchw);
|
||||
}
|
||||
|
||||
public BiasAddGrad(@NonNull INDArray input, @NonNull INDArray bias, @NonNull INDArray gradient) {
|
||||
this(input, bias, gradient, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
|
@ -35,6 +37,8 @@ import java.util.List;
|
|||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class FirstIndex extends BaseIndexAccumulation {
|
||||
protected Condition condition;
|
||||
protected double compare;
|
||||
|
@ -50,9 +54,6 @@ public class FirstIndex extends BaseIndexAccumulation {
|
|||
this.extraArgs = new Object[] {compare, eps, (double) mode};
|
||||
}
|
||||
|
||||
public FirstIndex() {}
|
||||
|
||||
|
||||
public FirstIndex(INDArray x, @NonNull Condition condition, int... dimension) {
|
||||
this(x, condition, false, dimension);
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -30,6 +31,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Data
|
||||
public class IAMax extends BaseIndexAccumulation {
|
||||
public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||
super(sameDiff, i_v, keepDims, dimensions);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -30,6 +31,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
@Data
|
||||
public class IAMin extends BaseIndexAccumulation {
|
||||
public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||
super(sameDiff, i_v, keepDims, dimensions);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -31,6 +32,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
public class IMax extends BaseIndexAccumulation {
|
||||
public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||
super(sameDiff, i_v, keepDims, dimensions);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
|
@ -30,6 +31,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
public class IMin extends BaseIndexAccumulation {
|
||||
public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||
super(sameDiff, i_v, keepDims, dimensions);
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -36,6 +37,7 @@ import java.util.Map;
|
|||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
@Data
|
||||
public class LastIndex extends BaseIndexAccumulation {
|
||||
protected Condition condition;
|
||||
protected double compare;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
|
@ -29,6 +30,7 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class ArgMax extends DynamicCustomOp {
|
||||
|
||||
protected DataType outputType;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
|
@ -34,6 +35,7 @@ import java.util.Map;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@Data
|
||||
public class ArgMin extends DynamicCustomOp {
|
||||
|
||||
protected DataType outputType = DataType.LONG;
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -38,8 +39,14 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
|
|||
|
||||
protected int classesDim;
|
||||
|
||||
public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
|
||||
super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
|
||||
// public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
|
||||
// super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
|
||||
// this.classesDim = classesDim;
|
||||
// addIArgument(classesDim);
|
||||
// }
|
||||
|
||||
public SoftmaxCrossEntropyWithLogitsLoss(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
|
||||
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
|
||||
this.classesDim = classesDim;
|
||||
addIArgument(classesDim);
|
||||
}
|
||||
|
@ -66,7 +73,8 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp {
|
|||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
//No external gradient
|
||||
//Args: logits, weigths, label
|
||||
SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(2), arg(0), arg(1), classesDim);
|
||||
SDVariable[] args = args();
|
||||
SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim);
|
||||
return Arrays.asList(grads);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,14 +20,18 @@ import lombok.NoArgsConstructor;
|
|||
import org.nd4j.autodiff.loss.LossReduce;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.BaseLoss;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
|
@ -56,4 +60,12 @@ public class SoftmaxCrossEntropyLossBp extends BaseLossBp {
|
|||
public String opName() {
|
||||
return "softmax_cross_entropy_loss_grad";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected 2 or 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
|
||||
|
||||
return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(1), inputDataTypes.get(2)); //Same as predictions
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,8 +19,10 @@ package org.nd4j.linalg.api.ops.impl.loss.bp;
|
|||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
|
@ -34,8 +36,8 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
|
|||
|
||||
protected int classesDim;
|
||||
|
||||
public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable weights, SDVariable labels, int classesDim) {
|
||||
super(null, sameDiff, new SDVariable[]{logits, weights, labels}, false);
|
||||
public SoftmaxCrossEntropyWithLogitsLossBp(SameDiff sameDiff, SDVariable logits, SDVariable labels, int classesDim) {
|
||||
super(null, sameDiff, new SDVariable[]{logits, labels}, false);
|
||||
this.classesDim = classesDim;
|
||||
addIArgument(classesDim);
|
||||
}
|
||||
|
@ -49,4 +51,9 @@ public class SoftmaxCrossEntropyWithLogitsLossBp extends DynamicCustomOp {
|
|||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
||||
throw new UnsupportedOperationException("Differentiation of " + getClass().getName() + " not supported");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
return Arrays.asList(arg(0).dataType(), arg(1).dataType());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,9 +16,12 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.reduce;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.*;
|
||||
|
@ -30,12 +33,9 @@ import java.util.*;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class SufficientStatistics extends DynamicCustomOp {
|
||||
|
||||
public SufficientStatistics() {
|
||||
}
|
||||
|
||||
|
||||
public SufficientStatistics(SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable axis, SDVariable shift) {
|
||||
super(null, sameDiff, argsNoNull(x, axis, shift), false);
|
||||
}
|
||||
|
@ -48,14 +48,30 @@ public class SufficientStatistics extends DynamicCustomOp {
|
|||
}
|
||||
}
|
||||
|
||||
public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes, INDArray shift) {
|
||||
if (shift != null)
|
||||
addInputArgument(x, axes, shift);
|
||||
else
|
||||
addInputArgument(x, axes);
|
||||
}
|
||||
|
||||
public SufficientStatistics(@NonNull INDArray x, @NonNull INDArray axes) {
|
||||
this(x,axes,null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "sufficient_statistics";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> grad) {
|
||||
throw new UnsupportedOperationException("Backprop not yet implemented for op: " + getClass().getSimpleName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
|
||||
// FIXME
|
||||
return Arrays.asList(inputDataTypes.get(0), inputDataTypes.get(0),inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -111,7 +111,7 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
int[][] deletedAxes = new int[][]{
|
||||
removeIndex(aAxes, sumAxes[0]),
|
||||
removeIndex(bAxes, sumAxes[1])};
|
||||
int[] gAxes = range(0, i_v1.get(0).getShape().length);
|
||||
int[] gAxes = range(0, i_v1.get(0).eval().shape().length);
|
||||
int[][] firstAxes = new int[][]{
|
||||
Arrays.copyOfRange(gAxes, deletedAxes[0].length, gAxes.length),
|
||||
deletedAxes[1]
|
||||
|
@ -144,18 +144,20 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
int[][] axes) {
|
||||
|
||||
int validationLength = Math.min(axes[0].length, axes[1].length);
|
||||
INDArray aArray = a.eval();
|
||||
INDArray bArray = b.eval();
|
||||
for (int i = 0; i < validationLength; i++) {
|
||||
if (a.getShape()[axes[0][i]] != b.getShape()[axes[1][i]])
|
||||
if (aArray.shape()[axes[0][i]] != bArray.shape()[axes[1][i]])
|
||||
throw new IllegalArgumentException("Size of the given axes at each dimension must be the same size.");
|
||||
if (axes[0][i] < 0)
|
||||
axes[0][i] += a.getShape().length;
|
||||
axes[0][i] += aArray.shape().length;
|
||||
if (axes[1][i] < 0)
|
||||
axes[1][i] += b.getShape().length;
|
||||
axes[1][i] += bArray.shape().length;
|
||||
|
||||
}
|
||||
|
||||
List<Integer> listA = new ArrayList<>();
|
||||
for (int i = 0; i < a.getShape().length; i++) {
|
||||
for (int i = 0; i < aArray.shape().length; i++) {
|
||||
if (!Ints.contains(axes[0], i))
|
||||
listA.add(i);
|
||||
}
|
||||
|
@ -164,7 +166,7 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
|
||||
|
||||
List<Integer> listB = new ArrayList<>();
|
||||
for (int i = 0; i < b.getShape().length; i++) {
|
||||
for (int i = 0; i < bArray.shape().length; i++) {
|
||||
if (!Ints.contains(axes[1], i))
|
||||
listB.add(i);
|
||||
}
|
||||
|
@ -172,9 +174,9 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
int[] newAxesB = Ints.concat(axes[1], Ints.toArray(listB));
|
||||
|
||||
int n2 = 1;
|
||||
int aLength = Math.min(a.getShape().length, axes[0].length);
|
||||
int aLength = Math.min(aArray.shape().length, axes[0].length);
|
||||
for (int i = 0; i < aLength; i++) {
|
||||
n2 *= a.getShape()[axes[0][i]];
|
||||
n2 *= aArray.shape()[axes[0][i]];
|
||||
}
|
||||
|
||||
//if listA and listB are empty these do not initialize.
|
||||
|
@ -186,13 +188,13 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
} else {
|
||||
oldShapeA = Longs.toArray(listA);
|
||||
for (int i = 0; i < oldShapeA.length; i++)
|
||||
oldShapeA[i] = a.getShape()[(int) oldShapeA[i]];
|
||||
oldShapeA[i] = aArray.shape()[(int) oldShapeA[i]];
|
||||
}
|
||||
|
||||
int n3 = 1;
|
||||
int bNax = Math.min(b.getShape().length, axes[1].length);
|
||||
int bNax = Math.min(bArray.shape().length, axes[1].length);
|
||||
for (int i = 0; i < bNax; i++) {
|
||||
n3 *= b.getShape()[axes[1][i]];
|
||||
n3 *= bArray.shape()[axes[1][i]];
|
||||
}
|
||||
|
||||
|
||||
|
@ -203,7 +205,7 @@ public class TensorMmul extends DynamicCustomOp {
|
|||
} else {
|
||||
oldShapeB = Longs.toArray(listB);
|
||||
for (int i = 0; i < oldShapeB.length; i++)
|
||||
oldShapeB[i] = b.getShape()[(int) oldShapeB[i]];
|
||||
oldShapeB[i] = bArray.shape()[(int) oldShapeB[i]];
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.reduce.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -30,7 +31,7 @@ import java.util.List;
|
|||
/**
|
||||
* @author Alex Black
|
||||
*/
|
||||
|
||||
@NoArgsConstructor
|
||||
public abstract class BaseReductionBp extends DynamicCustomOp {
|
||||
|
||||
protected boolean keepDims;
|
||||
|
@ -96,7 +97,12 @@ public abstract class BaseReductionBp extends DynamicCustomOp {
|
|||
addArgs();
|
||||
}
|
||||
|
||||
public BaseReductionBp(){}
|
||||
public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions){
|
||||
super(null, new INDArray[]{origInput1, origInput2, gradAtOutput}, new INDArray[]{output1, output2});
|
||||
this.keepDims = keepDims;
|
||||
this.dimensions = dimensions;
|
||||
addArgs();
|
||||
}
|
||||
|
||||
protected void addArgs(){
|
||||
addTArgument(keepDims ? 1 : 0);
|
||||
|
|
|
@ -16,17 +16,23 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.reduce.bp;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* Backprop op for Dot pairwise reduction operation
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
|
||||
@NoArgsConstructor
|
||||
public class DotBp extends BaseReductionBp {
|
||||
|
||||
public DotBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) {
|
||||
|
@ -37,10 +43,22 @@ public class DotBp extends BaseReductionBp {
|
|||
super(origInput1, origInput2, gradAtOutput, output, keepDims, dimensions);
|
||||
}
|
||||
|
||||
public DotBp(){}
|
||||
public DotBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput,
|
||||
INDArray outputX, INDArray outputY, boolean keepDims, int... dimensions) {
|
||||
super(origInput1, origInput2, gradAtOutput, outputX, outputY, keepDims, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "reduce_dot_bp";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 input datatype for %s, got input %s", getClass(), dataTypes);
|
||||
Preconditions.checkState(dataTypes.get(0).isFPType(), "First input must be a floating point type, got %s", dataTypes.get(0));
|
||||
Preconditions.checkState(dataTypes.get(1).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(1));
|
||||
Preconditions.checkState(dataTypes.get(2).isFPType(), "Second input (gradient at reduction output) must be a floating point type, got %s", dataTypes.get(2));
|
||||
return Arrays.asList(dataTypes.get(0), dataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.reduce.floating;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseReduceFloatOp;
|
||||
import org.nd4j.linalg.api.ops.BaseReduceOp;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Calculate a bias
|
||||
*
|
||||
* @author Adam Gibson
|
||||
*/
|
||||
public class Bias extends BaseReduceFloatOp {
|
||||
|
||||
private double mean;
|
||||
|
||||
public Bias(SameDiff sameDiff, SDVariable i_v, int[] dimensions, double mean) {
|
||||
super(sameDiff, i_v, dimensions);
|
||||
this.mean = mean;
|
||||
}
|
||||
|
||||
public Bias(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions, double mean) {
|
||||
super(sameDiff, i_v, i_v2, dimensions);
|
||||
this.mean = mean;
|
||||
}
|
||||
|
||||
public Bias() {}
|
||||
|
||||
public Bias(INDArray x, int... dimensions) {
|
||||
super(x, dimensions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> propertiesForFunction() {
|
||||
Map<String,Object> ret = new LinkedHashMap<>();
|
||||
ret.put("mean",mean);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bias";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
@ -45,6 +46,7 @@ public class SequenceMask extends DynamicCustomOp {
|
|||
public SequenceMask(SameDiff sameDiff, SDVariable input, SDVariable maxLen, DataType dataType) {
|
||||
super(null, sameDiff, new SDVariable[] {input, maxLen}, false);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public SequenceMask(SameDiff sameDiff, SDVariable input, int maxLen, DataType dataType) {
|
||||
|
@ -53,11 +55,21 @@ public class SequenceMask extends DynamicCustomOp {
|
|||
this.is_static_maxlen = true;
|
||||
addIArgument(maxLen);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public SequenceMask(SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||
super(null, sameDiff, new SDVariable[] {input}, false);
|
||||
this.dataType = dataType;
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public SequenceMask(INDArray input, int maxLen, DataType dataType) {
|
||||
addInputArgument(input);
|
||||
addIArgument(maxLen);
|
||||
//addIArgument(dataType.toInt());
|
||||
addDArgument(dataType);
|
||||
this.dataType = dataType;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
@ -39,23 +40,37 @@ import java.util.Map;
|
|||
* @author Adam Gibson
|
||||
*/
|
||||
@Slf4j
|
||||
@NoArgsConstructor
|
||||
public class ZerosLike extends DynamicCustomOp {
|
||||
|
||||
protected DataType outputType; //Allow customizing dtype for TF import
|
||||
|
||||
public ZerosLike() {
|
||||
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
|
||||
this(name, sameDiff, input, false, input.dataType());
|
||||
}
|
||||
|
||||
public ZerosLike(String name, SameDiff sameDiff, SDVariable input) {
|
||||
this(name, sameDiff, input, false);
|
||||
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, DataType dataType) {
|
||||
this(name, sameDiff, input, false, dataType);
|
||||
}
|
||||
|
||||
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace) {
|
||||
this(name, sameDiff, input, inPlace, input.dataType());
|
||||
}
|
||||
|
||||
public ZerosLike(String name, SameDiff sameDiff, SDVariable input, boolean inPlace, DataType dataType) {
|
||||
super(name, sameDiff, new SDVariable[]{input}, inPlace);
|
||||
addDArgument(dataType);
|
||||
}
|
||||
|
||||
public ZerosLike(INDArray in, INDArray out){
|
||||
this(in, out, in.dataType());
|
||||
}
|
||||
|
||||
public ZerosLike(INDArray in, INDArray out, DataType dataType) {
|
||||
super(null, in, out, null, null);
|
||||
if (dataType != null) {
|
||||
addDArgument(dataType);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -52,16 +53,13 @@ public class BatchToSpace extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
public BatchToSpace(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) {
|
||||
super(null, sameDiff, args, inPlace);
|
||||
super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(crops))}, inPlace);
|
||||
|
||||
this.blocks = blocks;
|
||||
this.crops = crops;
|
||||
|
||||
for (val b : blocks)
|
||||
addIArgument(b);
|
||||
|
||||
for (int e = 0; e < crops.length; e++)
|
||||
addIArgument(crops[e][0], crops[e][1]);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -53,16 +54,12 @@ public class SpaceToBatch extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
public SpaceToBatch(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) {
|
||||
super(null, sameDiff, args, inPlace);
|
||||
super(null, sameDiff, new SDVariable[]{args[0], sameDiff.constant(Nd4j.createFromArray(padding))}, inPlace);
|
||||
|
||||
this.blocks = blocks;
|
||||
this.padding = padding;
|
||||
|
||||
for (val b : blocks)
|
||||
addIArgument(b);
|
||||
|
||||
for (int e = 0; e < padding.length; e++)
|
||||
addIArgument(padding[e][0], padding[e][1]);
|
||||
addIArgument(blocks[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -58,7 +58,8 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.segment;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -31,6 +32,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class UnsortedSegmentMean extends DynamicCustomOp {
|
||||
|
||||
private int numSegments;
|
||||
|
@ -41,8 +43,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMean(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_mean";
|
||||
|
@ -60,7 +60,8 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.segment;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -31,6 +32,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class UnsortedSegmentMin extends DynamicCustomOp {
|
||||
|
||||
private int numSegments;
|
||||
|
@ -41,8 +43,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentMin(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_min";
|
||||
|
@ -60,7 +60,8 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.segment;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -31,6 +32,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class UnsortedSegmentProd extends DynamicCustomOp {
|
||||
|
||||
private int numSegments;
|
||||
|
@ -41,8 +43,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentProd(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_prod";
|
||||
|
@ -60,7 +60,8 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,10 +16,12 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.segment;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -31,18 +33,23 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
||||
|
||||
private int numSegments;
|
||||
|
||||
public UnsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) {
|
||||
addInputArgument(data, segmentIds);
|
||||
addIArgument(numSegments);
|
||||
this.numSegments = numSegments;
|
||||
}
|
||||
|
||||
public UnsortedSegmentSqrtN(SameDiff sameDiff, SDVariable data, SDVariable segmentIds, int numSegments) {
|
||||
super(null, sameDiff, new SDVariable[] {data, segmentIds}, false);
|
||||
this.numSegments = numSegments;
|
||||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentSqrtN(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_sqrt_n";
|
||||
|
@ -60,7 +67,8 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
List<DataType> out = new ArrayList<>();
|
||||
for( int i=0; i<numSegments; i++ ){
|
||||
out.add(inputDataTypes.get(0));
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.segment;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -32,6 +33,7 @@ import java.util.List;
|
|||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
public class UnsortedSegmentSum extends DynamicCustomOp {
|
||||
|
||||
private int numSegments;
|
||||
|
@ -42,8 +44,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
|
|||
addIArgument(numSegments);
|
||||
}
|
||||
|
||||
public UnsortedSegmentSum(){ }
|
||||
|
||||
@Override
|
||||
public String opName(){
|
||||
return "unsorted_segment_sum";
|
||||
|
@ -61,7 +61,8 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3),
|
||||
"Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
//TODO Allow customizing output type
|
||||
return Collections.singletonList(Nd4j.defaultFloatingPointType());
|
||||
}
|
||||
|
|
|
@ -50,6 +50,11 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 90000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testXwPlusB() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -319,7 +324,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testIm2Col() {
|
||||
OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||
//OpValidationSuite.ignoreFailing(); //TEMPORARY DUE TO JVM CRASH: https://github.com/deeplearning4j/deeplearning4j/issues/6873
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
int[][] inputSizes = new int[][]{{1, 3, 8, 8}, {3, 6, 12, 12}};
|
||||
|
|
|
@ -32,6 +32,9 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.custom.*;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.StopGradient;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
|
||||
|
@ -513,7 +516,7 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
@Test
|
||||
public void testTrace(){
|
||||
//TODO need to work out how to handle shape_op for scalars...
|
||||
OpValidationSuite.ignoreFailing();
|
||||
//OpValidationSuite.ignoreFailing();
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
for( int[] inShape : new int[][]{{3,3}}){
|
||||
|
||||
|
@ -546,12 +549,15 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
SDVariable x = sameDiff.var("x", arr);
|
||||
SDVariable y = sameDiff.var("y", arr2);
|
||||
SDVariable result = sameDiff.tensorMmul(x, y, new int[][]{{0}, {1}});
|
||||
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}), result.getShape());
|
||||
assertEquals(32, sameDiff.numElements());
|
||||
assertArrayEquals(ArrayUtil.getTensorMmulShape(new long[]{2, 2, 2}, new long[]{2, 2, 2}, new int[][]{{0}, {1}}),
|
||||
result.eval().shape());
|
||||
assertEquals(16, sameDiff.numElements());
|
||||
|
||||
SDVariable loss = sameDiff.standardDeviation(result, true);
|
||||
sameDiff.addLossVariable(loss);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sameDiff));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1782,4 +1788,338 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
assertEquals(exp, out); //Values in x not in y
|
||||
assertEquals(exp, outIdx); //Indices of the values in x not in y
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDivideNoNan() {
|
||||
OpValidationSuite.ignoreFailing(); //TODO: implement DivideNoNan.doDiff()
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
SDVariable input1 = sameDiff.var(in1);
|
||||
SDVariable input2 = sameDiff.var(in2);
|
||||
|
||||
INDArray expected = Nd4j.ones(3,4);
|
||||
|
||||
SDVariable output = new DivideNoNan(sameDiff, input1, input2).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDigamma() {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
-0.5772157,0.42278433,0.9227843,1.2561177,1.5061177,1.7061176,1.8727844,2.0156415,2.1406415,2.2517526,2.3517525,2.4426618
|
||||
}).reshape(3,4);
|
||||
|
||||
val tc = new OpTestCase(new Digamma(in1)).expectedOutput(0, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFlatten() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1, 27, 1).reshape(3,3,3);
|
||||
SDVariable sdx = sameDiff.var(x);
|
||||
|
||||
INDArray expected = Nd4j.linspace(DataType.DOUBLE,1,27,1);
|
||||
|
||||
SDVariable output = new Flatten(sameDiff, 'c', sdx).outputVariable();
|
||||
SDVariable loss = sameDiff.standardDeviation(sdx, true);
|
||||
sameDiff.addLossVariable(loss);
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFusedBatchNorm() {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*2*3*4).reshape(2,2,3,4);
|
||||
INDArray scale = Nd4j.create(DataType.DOUBLE, 4);
|
||||
scale.assign(0.5);
|
||||
INDArray offset = Nd4j.create(DataType.DOUBLE, 4);
|
||||
offset.assign(2.0);
|
||||
|
||||
SDVariable input1 = sameDiff.var(x);
|
||||
SDVariable input2 = sameDiff.var(scale);
|
||||
SDVariable input3 = sameDiff.var(offset);
|
||||
|
||||
INDArray expectedY = Nd4j.createFromArray(new double[]{
|
||||
985.5258, 985.5258, 985.5258, 985.5258,
|
||||
659.7321, 659.7321, 659.7321, 659.7321,
|
||||
399.0972, 399.0972, 399.0972, 399.0972,
|
||||
203.6210, 203.6210, 203.6210, 203.6210,
|
||||
73.3036, 73.3036, 73.3036, 73.3036,
|
||||
8.1448, 8.1448, 8.1448, 8.1448,
|
||||
8.1448, 8.1448, 8.1448, 8.1448,
|
||||
73.3036, 73.3036, 73.3036, 73.3036,
|
||||
203.6210, 203.6210, 203.6210, 203.6210,
|
||||
399.0972, 399.0972, 399.0972, 399.0972,
|
||||
659.7321, 659.7321, 659.7321, 659.7321,
|
||||
985.5258, 985.5258, 985.5258, 985.5258}).reshape(x.shape());
|
||||
INDArray expectedBatchMean = Nd4j.createFromArray(new double[]{23., 24., 25., 26.});
|
||||
INDArray expectedBatchVar = Nd4j.createFromArray(new double[]{208.00001526, 208.00001526, 208.00001526, 208.00001526});
|
||||
|
||||
SDVariable[] outputs = new FusedBatchNorm(sameDiff, input1, input2, input3, 0, 1).outputVariables();
|
||||
SDVariable loss = sameDiff.standardDeviation(input1, true);
|
||||
sameDiff.addLossVariable(loss);
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(outputs[0].name(), expectedY)
|
||||
.expectedOutput(outputs[1].name(), expectedBatchMean)
|
||||
.expectedOutput(outputs[2].name(), expectedBatchVar);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIgamma() {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
0.63212055,0.59399414,0.5768099,0.56652874,0.5595013,0.5542634,0.5501591,0.5463888,0.54329145,0.54048204,0.5378594,0.53233755
|
||||
}).reshape(3,4);
|
||||
|
||||
val tc = new OpTestCase(new Igamma(in1, in2)).expectedOutput(0, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testIgammaC() {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
0.36787945,0.40600586,0.42319012,0.43347126,0.4404987,0.44573656,0.4498409,0.45361117,0.45670855,0.459518,0.46214062,0.46766248
|
||||
}).reshape(3,4);
|
||||
|
||||
val tc = new OpTestCase(new Igammac(in1, in2)).expectedOutput(0, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLgamma() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(DataType.DOUBLE, 1, 12, 1).reshape(3, 4);
|
||||
SDVariable sdInput = sameDiff.var(in);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
0.0,0.0,0.6931472,1.7917595,3.1780539,4.787492,6.5792513,8.525162,10.604603,12.801827,15.104413,17.502308
|
||||
}).reshape(3,4);
|
||||
|
||||
SDVariable output = new Lgamma(sameDiff, sdInput).outputVariable();
|
||||
|
||||
SDVariable loss = sameDiff.standardDeviation(sdInput, true);
|
||||
sameDiff.addLossVariable(loss);
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLu() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in1 = Nd4j.createFromArray(new double[]{
|
||||
1., 2., 3., 0., 2., 3., 0., 0., 7.
|
||||
}).reshape(3,3);
|
||||
|
||||
SDVariable input1 = sameDiff.var(in1);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
1., 2., 3., 0., 2., 3., 0., 0., 7
|
||||
}).reshape(3,3);
|
||||
|
||||
INDArray pexpected = Nd4j.createFromArray(new int[]{
|
||||
0, 1, 2
|
||||
});
|
||||
|
||||
sameDiff.loss.l2Loss(input1);
|
||||
SDVariable[] output = new Lu(sameDiff, input1).outputVariables();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output[0].name(), expected)
|
||||
.expectedOutput(output[1].name(), pexpected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMatrixBandPart() {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray input = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,
|
||||
0.7271f,0.1804f,0.5056f,0.8925f,
|
||||
0.5461f,0.9234f,0.0856f,0.7938f}).reshape(3,4);
|
||||
|
||||
SDVariable sdInput = sameDiff.var(input);
|
||||
SDVariable sdInput1 = sameDiff.constant(1);
|
||||
SDVariable sdInput2 = sameDiff.constant(-1);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||
0.7788f, 0.8012f, 0.7244f, 0.2309f,
|
||||
0.7271f, 0.1804f, 0.5056f, 0.8925f,
|
||||
0.f, 0.9234f, 0.0856f, 0.7938f
|
||||
}).reshape(3,4);
|
||||
|
||||
sameDiff.loss.l2Loss(sdInput);
|
||||
SDVariable output = new MatrixBandPart(sameDiff, sdInput, 1, -1).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPolygamma() {
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
1.644934,-0.4041138,0.1189394,-0.03750069,0.01226151,-0.0041002957,0.001392272,-4.780109E-4,1.6549716E-4,-5.7675967E-5,2.0206635E-5,-7.1101636E-6
|
||||
}).reshape(3,4);
|
||||
|
||||
val tc = new OpTestCase(new Polygamma(in1, in2)).expectedOutput(0, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTriangularSolve() {
|
||||
|
||||
INDArray a = Nd4j.createFromArray(new float[]{
|
||||
3.f, 0.f, 0.f, 0.f,
|
||||
2.f, 1.f, 0.f, 0.f,
|
||||
1.f, 0.f, 1.f, 0.f,
|
||||
1.f, 1.f, 1.f, 1.f
|
||||
}).reshape(4,4);
|
||||
|
||||
INDArray b = Nd4j.createFromArray(new float[]{
|
||||
4.f, 2.f, 4.f, 2.f
|
||||
}).reshape(4,1);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||
1.333333f, 2.0f, 4.0f, 2.0f
|
||||
}).reshape(4,1);
|
||||
|
||||
val tc = new OpTestCase(new TriangularSolve(a, b, false, true)).expectedOutput(0, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAdd() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in1 = Nd4j.linspace(1, 12, 12);
|
||||
INDArray in2 = Nd4j.linspace(1, 12, 12);
|
||||
|
||||
SDVariable input1 = sameDiff.var(in1);
|
||||
SDVariable input2 = sameDiff.var(in2);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
2.0000, 4.0000, 6.0000, 8.0000, 10.0000, 12.0000, 14.0000, 16.0000, 18.0000, 20.0000, 22.0000, 24.0000
|
||||
});
|
||||
|
||||
SDVariable output = new BiasAdd(sameDiff, input1, input2, false).outputVariable();
|
||||
SDVariable loss = sameDiff.standardDeviation(input1, true);
|
||||
sameDiff.addLossVariable(loss);
|
||||
SDVariable loss2 = sameDiff.standardDeviation(input2, true);
|
||||
sameDiff.addLossVariable(loss2);
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBiasAddGrad() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray x = Nd4j.linspace(DataType.FLOAT,1, 24, 24).reshape(2,2,2,3);
|
||||
INDArray grad = Nd4j.linspace(DataType.FLOAT, 0.1, 0.1, 24).reshape(2,2,2,3);
|
||||
|
||||
INDArray bias = Nd4j.createFromArray(new float[]{-1.f, -2.f, -3.f});
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{9.2f, 10.f , 10.8f});
|
||||
|
||||
OpTestCase tc = new OpTestCase(new BiasAddGrad(x, bias, grad,false)).
|
||||
expectedOutput(0, grad).
|
||||
expectedOutput(1, expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRoll() {
|
||||
|
||||
INDArray x = Nd4j.createFromArray(new double[]{ 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42,
|
||||
21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42}).
|
||||
reshape(2,2,4,2);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{ 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42,
|
||||
12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32,
|
||||
21.41, 21.42, 22.11, 22.12
|
||||
}).reshape(x.shape());
|
||||
|
||||
int shift = 6;
|
||||
|
||||
val tc = new OpTestCase(new Roll(x,shift)).expectedOutput(0,expected);
|
||||
String err = OpValidation.validate(tc);
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,16 +35,20 @@ import org.nd4j.linalg.api.ops.CustomOp;
|
|||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
|
||||
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.SufficientStatistics;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce3.*;
|
||||
import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
@ -96,7 +100,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
@Test
|
||||
public void testZeroCount() {
|
||||
List<String> allFailed = new ArrayList<>();
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < 21; i++) {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
INDArray ia;
|
||||
|
@ -159,25 +163,25 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testReductionGradientsSimple() {
|
||||
OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
|
||||
//OpValidationSuite.ignoreFailing(); //TODO TEMPORARY DUE TO CRASHES
|
||||
//Test reductions: final and only function
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
List<String> failed = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < 21; i++) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
int nOut = 4;
|
||||
int minibatch = 10;
|
||||
SDVariable input = sd.var("in", -1, nOut);
|
||||
SDVariable input = sd.var("in", minibatch, nOut);
|
||||
INDArray inputArr = Nd4j.randn(minibatch, nOut).muli(100);
|
||||
long length = nOut * minibatch;
|
||||
|
||||
SDVariable loss;
|
||||
String name;
|
||||
TestCase tc = new TestCase(sd);
|
||||
boolean gradCheck = true;
|
||||
switch (i) {
|
||||
case 0:
|
||||
loss = sd.mean("loss", input);
|
||||
|
@ -234,11 +238,13 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
loss = sd.math().countNonZero("loss", input);
|
||||
name = "countNonZero";
|
||||
tc.expectedOutput("loss", Nd4j.scalar(inputArr.length()));
|
||||
gradCheck = false; //Long out, not floating point
|
||||
break;
|
||||
case 11:
|
||||
loss = sd.math().countZero("loss", input);
|
||||
name = "countZero";
|
||||
tc.expectedOutput("loss", Nd4j.scalar(0));
|
||||
tc.expectedOutput("loss", Nd4j.scalar(0L));
|
||||
gradCheck = false; //Long out, not floating point
|
||||
break;
|
||||
case 12:
|
||||
loss = sd.math().amax("loss", input);
|
||||
|
@ -272,7 +278,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
loss = sd.math().logSumExp("loss", input);
|
||||
INDArray expArr = Transforms.exp(inputArr);
|
||||
double sum = expArr.sumNumber().doubleValue();
|
||||
tc.expected("loss", Nd4j.create(new double[]{Math.log(sum)}));
|
||||
tc.expected("loss", Nd4j.scalar(Math.log(sum)));
|
||||
break;
|
||||
case 18:
|
||||
inputArr = Nd4j.rand(minibatch, nOut);
|
||||
|
@ -307,9 +313,15 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
log.info("*** Starting test: " + msg);
|
||||
|
||||
sd.associateArrayWithVariable(inputArr, input);
|
||||
|
||||
if(gradCheck) {
|
||||
sd.addLossVariable(loss);
|
||||
}
|
||||
|
||||
tc.testName(msg);
|
||||
if(!gradCheck){
|
||||
tc.gradientCheck(false);
|
||||
}
|
||||
|
||||
String error = OpValidation.validate(tc, true);
|
||||
if (error != null)
|
||||
failed.add(error);
|
||||
|
@ -629,14 +641,14 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
List<String> failed = new ArrayList<>();
|
||||
for (int[] reduceDims : new int[][]{{Integer.MAX_VALUE}, {0, 1, 2}, {0}, {1}, {2}, {0, 1}, {0, 2}, {1, 2}}) {
|
||||
for (int i = 6; i < 7; i++) {
|
||||
for (int i = 0; i < 7; i++) {
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
sd.setLogExecution(false);
|
||||
|
||||
|
||||
SDVariable in = sd.var("in", -1, d1, d2);
|
||||
SDVariable in2 = sd.var("in2", -1, d1, d2);
|
||||
SDVariable in = sd.var("in", d1, d1, d2);
|
||||
SDVariable in2 = sd.var("in2", d0, d1, d2);
|
||||
|
||||
INDArray inArr = Nd4j.randn(new int[]{d0, d1, d2}).muli(100);
|
||||
INDArray in2Arr = Nd4j.randn(inArr.shape()).muli(100);
|
||||
|
@ -645,40 +657,43 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
SDVariable reduced;
|
||||
String name;
|
||||
TestCase tc = new TestCase(sd);
|
||||
Double maxRelError = null;
|
||||
switch (i) {
|
||||
case 0:
|
||||
reduced = sd.math().manhattanDistance(in, in2, reduceDims);
|
||||
name = "manhattan";
|
||||
exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new ManhattanDistance(inArr, in2Arr, null, false, false, reduceDims));
|
||||
break;
|
||||
case 1:
|
||||
reduced = sd.math().euclideanDistance(in, in2, reduceDims);
|
||||
name = "euclidean";
|
||||
exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new EuclideanDistance(inArr, in2Arr, null, false, false, reduceDims));
|
||||
break;
|
||||
case 2:
|
||||
inArr.muli(1e-4);
|
||||
in2Arr.muli(1e-4);
|
||||
reduced = sd.math().cosineSimilarity(in, in2, reduceDims);
|
||||
name = "cosine";
|
||||
exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new CosineSimilarity(inArr, in2Arr, null, false, false, reduceDims));
|
||||
maxRelError = 1e-4;
|
||||
break;
|
||||
case 3:
|
||||
reduced = sd.math().cosineDistance(in, in2, reduceDims);
|
||||
name = "cosinedistance";
|
||||
exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new CosineDistance(inArr, in2Arr, null, false, false, reduceDims));
|
||||
maxRelError = 1e-4;
|
||||
break;
|
||||
case 4:
|
||||
reduced = sd.math().hammingDistance(in, in2, reduceDims);
|
||||
name = "hamming";
|
||||
exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new HammingDistance(inArr, in2Arr, null, false, false, reduceDims));
|
||||
break;
|
||||
case 5:
|
||||
name = "jaccard";
|
||||
reduced = sd.math().jaccardDistance(name, in, in2, reduceDims);
|
||||
inArr.divi(100).addi(0.1);
|
||||
in2Arr.divi(100).addi(0.1);
|
||||
exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, true, false, reduceDims));
|
||||
exp = Nd4j.getExecutioner().exec(new JaccardDistance(inArr, in2Arr, null, false, false, reduceDims));
|
||||
|
||||
if (OpValidationSuite.IGNORE_FAILING && reduceDims.length == 2)
|
||||
continue;
|
||||
|
@ -708,6 +723,9 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
|
||||
tc.expected(reduced, exp);
|
||||
|
||||
if(maxRelError != null)
|
||||
tc.gradCheckMaxRelativeError(maxRelError);
|
||||
|
||||
String error = OpValidation.validate(tc, true);
|
||||
if (error != null) {
|
||||
failed.add(msg + " - " + error);
|
||||
|
@ -768,7 +786,6 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
@Ignore("AB 2019/06/24 - Failing: Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912")
|
||||
public void testNormalizeMomentsOp() {
|
||||
INDArray data = Nd4j.linspace(1, 100, 100, DataType.DOUBLE).reshape(10, 10);
|
||||
INDArray ssSum = data.sum(0);
|
||||
|
@ -780,7 +797,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
INDArray mean = Nd4j.createUninitialized(DataType.DOUBLE, meanExp.shape());
|
||||
INDArray var = Nd4j.createUninitialized(DataType.DOUBLE, varExp.shape());
|
||||
|
||||
OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.INT, 10), ssSum, ssSqSum, mean, var));
|
||||
OpTestCase op = new OpTestCase(new NormalizeMoments(Nd4j.scalar(DataType.DOUBLE, 10), ssSum, ssSqSum, mean, var));
|
||||
op.expectedOutput(0, meanExp);
|
||||
op.expectedOutput(1, varExp);
|
||||
|
||||
|
@ -821,7 +838,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
List<String> failed = new ArrayList<>();
|
||||
List<int[]> dims = Arrays.asList(new int[]{0}, new int[]{1}, new int[]{0, 1}, new int[0]);
|
||||
|
||||
INDArray in = Nd4j.rand(3, 4);
|
||||
INDArray in = Nd4j.rand(DataType.DOUBLE,3, 4);
|
||||
|
||||
for (int t = 0; t < 4; t++) {
|
||||
int[] d = dims.get(t);
|
||||
|
@ -838,52 +855,47 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
switch (i) {
|
||||
case 0:
|
||||
reduce = s.argmax(dim);
|
||||
exp = Nd4j.argMax(in, dim).castTo(DataType.DOUBLE);
|
||||
exp = Nd4j.argMax(in, dim);
|
||||
name = "argmax";
|
||||
break;
|
||||
case 1:
|
||||
reduce = s.argmin(dim);
|
||||
exp = Nd4j.argMin(in, dim).castTo(DataType.DOUBLE);
|
||||
exp = Nd4j.argMin(in, dim);
|
||||
name = "argmin";
|
||||
break;
|
||||
case 2:
|
||||
reduce = sd.math().iamax(s, dim);
|
||||
exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
|
||||
exp = exp.castTo(DataType.DOUBLE);
|
||||
name = "iamax";
|
||||
break;
|
||||
case 3:
|
||||
reduce = sd.math().iamin(s, dim);
|
||||
exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
|
||||
exp = exp.castTo(DataType.DOUBLE);
|
||||
name = "iamin";
|
||||
break;
|
||||
case 4:
|
||||
reduce = sd.math().firstIndex(s, Conditions.greaterThan(0), dim);
|
||||
exp = in.sum(dim).assign(0);
|
||||
exp = exp.castTo(DataType.DOUBLE);
|
||||
exp = in.sum(dim).assign(0).castTo(DataType.INT64);
|
||||
name = "firstindex";
|
||||
break;
|
||||
case 5:
|
||||
reduce = sd.math().lastIndex(s, Conditions.greaterThan(0), dim);
|
||||
if (t == 0) exp = Nd4j.create(new double[]{2, 2, 2, 2});
|
||||
else if (t == 1) exp = Nd4j.create(new double[]{3, 3, 3});
|
||||
else exp = Nd4j.scalar(11.0);
|
||||
exp = exp.castTo(DataType.DOUBLE);
|
||||
if (t == 0) exp = Nd4j.createFromArray(2L, 2, 2, 2);
|
||||
else if (t == 1) exp = Nd4j.createFromArray(3L, 3, 3);
|
||||
else exp = Nd4j.scalar(11L);
|
||||
name = "lastindex";
|
||||
break;
|
||||
case 6:
|
||||
reduce = sd.matchConditionCount("count", s, Conditions.greaterThan(0), false, dim);
|
||||
if (t == 0) exp = Nd4j.create(new double[]{3, 3, 3, 3});
|
||||
else if (t == 1) exp = Nd4j.create(new double[]{4, 4, 4});
|
||||
else exp = Nd4j.scalar(12.0);
|
||||
exp = exp.castTo(DataType.DOUBLE);
|
||||
if (t == 0) exp = Nd4j.createFromArray(3L, 3, 3, 3);
|
||||
else if (t == 1) exp = Nd4j.createFromArray(4L, 4, 4);
|
||||
else exp = Nd4j.scalar(12L);
|
||||
name = "matchConditionCount";
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
}
|
||||
|
||||
SDVariable preCast = reduce;
|
||||
reduce = reduce.castTo(DataType.DOUBLE);
|
||||
|
||||
SDVariable loss;
|
||||
|
@ -894,7 +906,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
TestCase tc = new TestCase(sd)
|
||||
.expected(reduce, exp)
|
||||
.expected(preCast, exp)
|
||||
.gradientCheck(false)
|
||||
.testName(name + " - " + (dim == null ? null : Arrays.toString(dim)));
|
||||
|
||||
|
@ -1335,4 +1347,254 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
}
|
||||
}
|
||||
}
|
||||
@Test
|
||||
public void testSufficientStatisticsOp() {
|
||||
INDArray data = Nd4j.createFromArray(new double[]{
|
||||
5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
|
||||
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5
|
||||
}).reshape(2,2,2,4);
|
||||
INDArray axes = Nd4j.linspace(DataType.LONG, 0, 3, 1);
|
||||
|
||||
OpTestCase op = new OpTestCase(new SufficientStatistics(data, axes));
|
||||
|
||||
INDArray expected1 = Nd4j.scalar(8.0);
|
||||
INDArray expected2 = Nd4j.createFromArray(new double[]{
|
||||
30.2, 5., 7.8, 22.8
|
||||
});
|
||||
INDArray expected3 = Nd4j.createFromArray(new double[]{
|
||||
154.22, 7., 14.34, 103.62
|
||||
});
|
||||
|
||||
op.expectedOutput(0, expected1);
|
||||
op.expectedOutput(1, expected2);
|
||||
op.expectedOutput(2, expected3);
|
||||
|
||||
String err = OpValidation.validate(op);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testStandardDeviation() {
|
||||
|
||||
for (boolean keepDims : new boolean[]{false, true}) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 8, 8).reshape(2, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
2, 2, 2, 2
|
||||
});
|
||||
|
||||
if(keepDims){
|
||||
expected = expected.reshape(1,4);
|
||||
}
|
||||
|
||||
SDVariable output = new StandardDeviation(sameDiff, input, false, keepDims, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSquaredNorm() {
|
||||
|
||||
for (boolean keepDims : new boolean[]{false, true}) {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 4, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.scalar(30.0000);
|
||||
if(keepDims)
|
||||
expected = expected.reshape(1);
|
||||
|
||||
SDVariable output = new SquaredNorm(sameDiff, input, keepDims, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShannonEntropy() {
|
||||
OpValidationSuite.ignoreFailing(); //AB 2020/02/11 https://github.com/eclipse/deeplearning4j/issues/8695
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 4, 4).castTo(DataType.DOUBLE);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.scalar(-69.68162);
|
||||
|
||||
SDVariable output = new ShannonEntropy(sameDiff, input, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEntropy() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 4, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
double expected = -10.2273;
|
||||
|
||||
SDVariable output = new Entropy(sameDiff, input, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), Nd4j.scalar(expected));
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAMean() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
5.0000, 6.0000, 7.0000, 8.0000
|
||||
});
|
||||
|
||||
SDVariable output = new AMean(sameDiff, input, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMean() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
5.0000, 6.0000, 7.0000, 8.0000
|
||||
});
|
||||
|
||||
SDVariable output = new Mean(sameDiff, input, false, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNorm1() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
15.0000, 18.0000, 21.0000, 24.0000
|
||||
});
|
||||
|
||||
SDVariable output = new Norm1(sameDiff, input, false, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNorm2() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
10.3441, 11.8322, 13.3791, 14.9666
|
||||
});
|
||||
|
||||
SDVariable output = new Norm2(sameDiff, input, false, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNormMax() {
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray in = Nd4j.linspace(1, 12, 12).reshape(3, 4);
|
||||
SDVariable input = sameDiff.var(in);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
9.0000, 10.0000, 11.0000, 12.0000
|
||||
});
|
||||
|
||||
SDVariable output = new NormMax(sameDiff, input, false, new int[]{0}).outputVariable();
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSoftmaxCrossEntropyWithLogitsLoss() {
|
||||
OpValidationSuite.ignoreFailing();
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
|
||||
INDArray labels = Nd4j.createFromArray(new double[]{
|
||||
0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0
|
||||
}).reshape(2,3,4);
|
||||
|
||||
INDArray logits = Nd4j.linspace(DataType.DOUBLE, 0.1, 0.1, 24).reshape(2,3,4);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328
|
||||
}).reshape(3,4);
|
||||
|
||||
SDVariable sdLogits = sameDiff.var("logits", logits);
|
||||
SDVariable sdLabels = sameDiff.var("labels", labels);
|
||||
SDVariable loss = sameDiff.math().abs(sdLogits);
|
||||
|
||||
|
||||
SDVariable output = new SoftmaxCrossEntropyWithLogitsLoss(sameDiff, sdLogits, sdLabels, 0).outputVariable();
|
||||
sameDiff.setLossVariables(output);
|
||||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(output.name(), expected);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1016,7 +1016,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testConstant(){
|
||||
OpValidationSuite.ignoreFailing();
|
||||
//OpValidationSuite.ignoreFailing();
|
||||
|
||||
//Case 0: no shape
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1035,7 +1035,9 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
|
||||
loss = constant.std(true);
|
||||
|
||||
assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
|
||||
assertNull(OpValidation.validate(new TestCase(sd)
|
||||
.gradientCheck(false)
|
||||
.expected(constant, Nd4j.create(DataType.FLOAT, 3,4,5))));
|
||||
}
|
||||
|
||||
|
||||
|
@ -1272,7 +1274,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable data = sd.var("data", d);
|
||||
SDVariable segments = sd.var("segments", s);
|
||||
SDVariable segments = sd.constant("segments", s);
|
||||
|
||||
SDVariable sm;
|
||||
INDArray exp;
|
||||
|
@ -1326,6 +1328,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
SDVariable loss = sm.std(true);
|
||||
sd.addLossVariable(loss);
|
||||
|
||||
TestCase tc = new TestCase(sd)
|
||||
.testName(op)
|
||||
|
@ -1363,17 +1366,19 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testSequenceMask() {
|
||||
OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue?
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray arr = Nd4j.create(new float[] {1, 3, 2}).reshape(3);
|
||||
SDVariable lengths = sameDiff.var("lengths", arr);
|
||||
INDArray arr = Nd4j.createFromArray(new int[] {1, 3, 2});
|
||||
// arr is not trainable, so it's constant in model
|
||||
SDVariable lengths = sameDiff.constant(arr);
|
||||
|
||||
// Test with static max len
|
||||
int maxlen = 5;
|
||||
INDArray expected = Nd4j.create(new float[] {1, 0, 0, 0, 0,
|
||||
1, 1, 1, 0, 0,
|
||||
1, 1, 0, 0, 0},
|
||||
new long[]{3, 5});
|
||||
INDArray expected = Nd4j.create(new float[] {
|
||||
1.f, 0.f, 0.f, 0.f, 0.f,
|
||||
1.f, 1.f, 1.f, 0.f, 0.f,
|
||||
1.f, 1.f, 0.f, 0.f, 0.f
|
||||
}).reshape(3,5);
|
||||
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.FLOAT));
|
||||
SDVariable result1 = sameDiff.sequenceMask(lengths, maxlen, DataType.FLOAT);
|
||||
assertArrayEquals(expected.shape(), result1.eval().shape());
|
||||
assertEquals(expected, result1.eval());
|
||||
|
@ -1382,14 +1387,14 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
String err = OpValidation.validate(new TestCase(sameDiff)
|
||||
.expected(result1, expected)
|
||||
.gradCheckSkipVariables(lengths.name()));
|
||||
.gradientCheck(false));
|
||||
assertNull(err);
|
||||
|
||||
// Test with dynamic maxlen
|
||||
lengths = sameDiff.var("lengths2", arr); // required because of an internal samediff bug
|
||||
SDVariable maxLen = sameDiff.var("maxLen", Nd4j.create(new float[]{5}).reshape(1));
|
||||
lengths = sameDiff.constant("lengths2", arr);
|
||||
SDVariable maxLen = sameDiff.constant("maxLen", Nd4j.scalar(5));
|
||||
SDVariable result2 = sameDiff.sequenceMask(lengths, maxLen, DataType.FLOAT);
|
||||
assertArrayEquals(expected.shape(), result2.eval().shape());
|
||||
// assertArrayEquals(expected.shape(), result2.eval().shape());
|
||||
assertEquals(expected, result2.eval());
|
||||
}
|
||||
|
||||
|
|
|
@ -303,7 +303,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testBatchToSpace() {
|
||||
OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
||||
Nd4j.getRandom().setSeed(1337);
|
||||
|
||||
int miniBatch = 4;
|
||||
|
@ -314,7 +314,6 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
int[] cropShape = new int[]{M, 2};
|
||||
|
||||
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
|
||||
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
|
||||
INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -323,7 +322,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space")
|
||||
.addInputs(input, blocks, crops)
|
||||
.addInputs(input, crops)
|
||||
.addIntegerArguments(2)
|
||||
.addOutputs(expOut).build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
|
||||
|
@ -340,7 +340,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
@Test
|
||||
public void testSpaceToBatch() {
|
||||
OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
||||
//OpValidationSuite.ignoreFailing(); //TODO: https://github.com/deeplearning4j/deeplearning4j/issues/6863
|
||||
|
||||
Nd4j.getRandom().setSeed(7331);
|
||||
|
||||
|
@ -352,7 +352,6 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
int[] paddingShape = new int[]{M, 2};
|
||||
|
||||
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
|
||||
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT);
|
||||
INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT);
|
||||
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -361,7 +360,8 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
|
||||
INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("space_to_batch")
|
||||
.addInputs(input, blocks, padding)
|
||||
.addIntegerArguments(2)
|
||||
.addInputs(input, padding)
|
||||
.addOutputs(expOut).build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
|||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||
|
@ -1737,4 +1738,19 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceMask() {
|
||||
INDArray arr = Nd4j.createFromArray(new int[]{1, 3, 2});
|
||||
// Test with static max len
|
||||
int maxlen = 2;
|
||||
INDArray expected = Nd4j.createFromArray(new int[]{
|
||||
1,0,0,
|
||||
1,1,1,
|
||||
1,1,0
|
||||
}).reshape(3, 3);
|
||||
|
||||
INDArray[] ret = Nd4j.exec(new SequenceMask(arr, maxlen, DataType.INT32));
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -318,8 +318,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
|
|||
|
||||
"num2Scalar" should "convert number to Scalar INDArray" in {
|
||||
|
||||
assert(1.toScalar.data() == List(1).toNDArray.data())
|
||||
assert(2f.toScalar.data() == List(2).toNDArray.data())
|
||||
assert(3d.toScalar.data() == List(3).toNDArray.data())
|
||||
assert(1.toScalar.reshape(1) == List(1).toNDArray)
|
||||
assert(2f.toScalar.reshape(1) == List(2f).toNDArray)
|
||||
assert(3d.toScalar.reshape(1) == List(3d).toNDArray)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue