Dev branch merge: dev_20190606 (#7904)

* correct logsoftmax loss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test fix

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo in case of input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple of tests broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack working correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks conforming with TF for concat op and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks
master
Alex Black 2019-06-15 21:34:34 +10:00 committed by GitHub
parent 32e5cc1945
commit 68ea5f3688
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
169 changed files with 7207 additions and 3633 deletions

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
@ -283,7 +284,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.computeGradientAndScore(); net.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-6); assertEquals(net2.score(), net.score(), 1e-6);
Map<String, INDArray> gradient = net.gradient().gradientForVariable(); Map<String, INDArray> gradient = net.gradient().gradientForVariable();
@ -441,85 +441,87 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
int numInputClasses = 10; int numInputClasses = 10;
int timeSeriesLength = 5; int timeSeriesLength = 5;
for (int nExamples : miniBatchSizes) { for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
Nd4j.getRandom().setSeed(12345); for (int nExamples : miniBatchSizes) {
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list() .updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) .layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
.nOut(5).build()) .nOut(5).build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build()) .nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init(); net.init();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list() .updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
.build()) .build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build()) .nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init(); net2.init();
net2.setParams(net.params().dup()); net2.setParams(net.params().dup());
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength); INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength);
for (int i = 0; i < nExamples; i++) { for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) { for (int j = 0; j < timeSeriesLength; j++) {
int inIdx = r.nextInt(numInputClasses); int inIdx = r.nextInt(numInputClasses);
inEmbedding.putScalar(new int[]{i, 0, j}, inIdx); inEmbedding.putScalar(new int[]{i, 0, j}, inIdx);
inDense.putScalar(new int[]{i, inIdx, j}, 1.0); inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
int outIdx = r.nextInt(4); int outIdx = r.nextInt(4);
labels.putScalar(new int[]{i, outIdx, j}, 1.0); labels.putScalar(new int[]{i, outIdx, j}, 1.0);
}
} }
}
INDArray inputMask = Nd4j.zeros(nExamples, timeSeriesLength); INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
for (int i = 0; i < nExamples; i++) { for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) { for (int j = 0; j < timeSeriesLength; j++) {
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
}
} }
}
net.setLayerMaskArrays(inputMask, null); net.setLayerMaskArrays(inputMask, null);
net2.setLayerMaskArrays(inputMask, null); net2.setLayerMaskArrays(inputMask, null);
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false); List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
List<INDArray> actDense = net2.feedForward(inDense, false); List<INDArray> actDense = net2.feedForward(inDense, false);
for (int i = 1; i < actEmbedding.size(); i++) { for (int i = 1; i < actEmbedding.size(); i++) {
assertEquals(actDense.get(i), actEmbedding.get(i)); assertEquals(actDense.get(i), actEmbedding.get(i));
} }
net.setLabels(labels); net.setLabels(labels);
net2.setLabels(labels); net2.setLabels(labels);
net.computeGradientAndScore(); net.computeGradientAndScore();
net2.computeGradientAndScore(); net2.computeGradientAndScore();
System.out.println(net.score() + "\t" + net2.score()); System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-5); assertEquals(net2.score(), net.score(), 1e-5);
Map<String, INDArray> gradients = net.gradient().gradientForVariable(); Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable(); Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
assertEquals(gradients.keySet(), gradients2.keySet()); assertEquals(gradients.keySet(), gradients2.keySet());
for (String s : gradients.keySet()) { for (String s : gradients.keySet()) {
assertEquals(gradients2.get(s), gradients.get(s)); assertEquals(gradients2.get(s), gradients.get(s));
}
} }
} }
} }
@ -583,6 +585,104 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
} }
} }
/**
 * Compares a network whose first layer is an EmbeddingSequenceLayer against an otherwise
 * identical network whose first layer is a DenseLayer, with the same input mask applied to both.
 * The embedding network is fed class indices; the dense network is fed the matching one-hot input.
 * The comparison is repeated for every combination of mask data type, input/label data type,
 * embedding input rank (2 or 3) and minibatch size (1 or 3), and asserts equal activations
 * (from layer index 2 onwards), equal scores (tolerance 1e-5) and equal gradients.
 */
@Test
public void testEmbeddingSequenceLayerWithMasking() {
//Idea: have masking on the input with an embedding and dense layers on input
//Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked
int[] miniBatchSizes = {1, 3};
// NOTE(review): nIn does not appear to be referenced anywhere below - confirm whether it can be removed
int nIn = 2;
// Single seeded generator shared by all loop iterations; the draws below depend on statement order
Random r = new Random(12345);
int numInputClasses = 10;
int timeSeriesLength = 5;
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
for(int inputRank : new int[]{2, 3}) {
for (int nExamples : miniBatchSizes) {
Nd4j.getRandom().setSeed(12345);
// Network under test: EmbeddingSequenceLayer -> DenseLayer -> LSTM -> RnnOutputLayer
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
.nOut(5).build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.setInputType(InputType.recurrent(1)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
// Reference network: same layer sizes, but layer 0 is a DenseLayer instead of the embedding layer
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
.build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.setInputType(InputType.recurrent(1)).build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
// Give the reference network a copy of the first network's parameters so both start identical
net2.setParams(net.params().dup());
// Embedding input is [nExamples, timeSeriesLength] for rank 2, or [nExamples, 1, timeSeriesLength] for rank 3
INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
// Same random class per (example, time step): stored as an index for the embedding net, one-hot for the dense net
int inIdx = r.nextInt(numInputClasses);
inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx);
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
// One-hot label at a random output index
int outIdx = r.nextInt(4);
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
}
}
// Random 0/1 mask of shape [nExamples, timeSeriesLength], created with the mask data type under test
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
}
}
// Apply the same input mask to both networks (no label mask)
net.setLayerMaskArrays(inputMask, null);
net2.setLayerMaskArrays(inputMask, null);
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
List<INDArray> actDense = net2.feedForward(inDense, false);
for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape)
assertEquals(actDense.get(i), actEmbedding.get(i));
}
net.setLabels(labels);
net2.setLabels(labels);
net.computeGradientAndScore();
net2.computeGradientAndScore();
assertEquals(net2.score(), net.score(), 1e-5);
// Gradients must cover the same variable names and be equal for each of them
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
assertEquals(gradients.keySet(), gradients2.keySet());
for (String s : gradients.keySet()) {
assertEquals(gradients2.get(s), gradients.get(s));
}
}
}
}
}
}
@EqualsAndHashCode @EqualsAndHashCode
private static class WordVectorsMockup implements EmbeddingInitializer { private static class WordVectorsMockup implements EmbeddingInitializer {

View File

@ -213,6 +213,12 @@ public class TestSameDiffConv extends BaseDL4JTest {
INDArray outLoaded = netLoaded.output(in); INDArray outLoaded = netLoaded.output(in);
assertEquals(msg, outExp, outLoaded); assertEquals(msg, outExp, outLoaded);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = net.output(newIn);
INDArray outMb = net2.output(newIn);
assertEquals(outMb, outMbsd);
} }
} }
} }
@ -306,6 +312,10 @@ public class TestSameDiffConv extends BaseDL4JTest {
assertTrue(msg, gradOK); assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(f, f);
net.output(newIn);
} }
} }
} }

View File

@ -137,6 +137,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
INDArray outLoaded = netLoaded.output(in); INDArray outLoaded = netLoaded.output(in);
assertEquals(outExp, outLoaded); assertEquals(outExp, outLoaded);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = net.output(newIn);
INDArray outMb = net2.output(newIn);
assertEquals(outMb, outMbsd);
} }
} }
} }
@ -314,6 +320,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
netSD.computeGradientAndScore(); netSD.computeGradientAndScore();
// netStandard.computeGradientAndScore(); // netStandard.computeGradientAndScore();
// assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient()); // assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient());
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStandard.output(newIn);
assertEquals(outMb, outMbsd);
} }
} }
} }
@ -377,6 +389,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
assertEquals(s, netStandard.params(), netSD.params()); assertEquals(s, netStandard.params(), netSD.params());
assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray()); assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
} }
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures());
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStandard.output(newIn);
assertEquals(outMb, outMbsd);
} }
@Test @Test
@ -417,6 +435,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
assertTrue(msg, gradOK); assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net); TestUtils.testModelSerialization(net);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(f, f);
net.output(newIn);
} }
} }
} }

View File

@ -166,6 +166,12 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
outSD = loaded.outputSingle(in); outSD = loaded.outputSingle(in);
outStd = netStandard.outputSingle(in); outStd = netStandard.outputSingle(in);
assertEquals(outStd, outSD); assertEquals(outStd, outSD);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn)[0];
INDArray outMb = netStandard.output(newIn)[0];
assertEquals(outMb, outMbsd);
} }
} }
} }

View File

@ -115,6 +115,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
outStd = std.outputSingle(in); outStd = std.outputSingle(in);
assertEquals(outStd, outLambda); assertEquals(outStd, outLambda);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = lambda.output(newIn)[0];
INDArray outMb = std.output(newIn)[0];
assertEquals(outMb, outMbsd);
} }
@Test @Test
@ -186,5 +192,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
outStd = std.output(in1, in2)[0]; outStd = std.output(in1, in2)[0];
assertEquals(outStd, outLambda); assertEquals(outStd, outLambda);
//Sanity check on different minibatch sizes:
INDArray newIn1 = Nd4j.vstack(in1, in1);
INDArray newIn2 = Nd4j.vstack(in2, in2);
INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
INDArray outMb = std.output(newIn1, newIn2)[0];
assertEquals(outMb, outMbsd);
} }
} }

View File

@ -90,6 +90,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone()); MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
net.init(); net.init();
net.fit(ds); net.fit(ds);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStd.output(newIn);
assertEquals(outMb, outMbsd);
} }
@ -164,6 +170,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone()); MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
net.init(); net.init();
net.fit(ds); net.fit(ds);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStd.output(newIn);
assertEquals(outMb, outMbsd);
} }
} }

View File

@ -32,6 +32,7 @@ import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl; import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.fasttext.FastText;
import org.deeplearning4j.models.glove.Glove; import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors;
@ -3090,6 +3091,42 @@ public class WordVectorSerializer {
return word2Vec; return word2Vec;
} }
/**
 * Writes a FastText model to a file using Java object serialization.
 *
 * @param vectors the model to serialize; must not be null
 * @param path    the destination file; must not be null
 * @throws IOException if the file cannot be opened, or if writing, flushing or closing the
 *                     stream fails (previously flush/close failures were only printed, so the
 *                     method could return normally after an incomplete write)
 */
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
    // Both streams are declared as resources so the file handle is released even when the
    // ObjectOutputStream constructor itself throws.
    try (FileOutputStream fileOut = new FileOutputStream(path);
         ObjectOutputStream outputStream = new ObjectOutputStream(fileOut)) {
        outputStream.writeObject(vectors);
        outputStream.flush();
    }
}
/**
 * Reads a FastText model from a file using Java object serialization.
 * <p>
 * This is a best-effort read: any failure is reported on standard error and {@code null}
 * is returned instead of an exception being thrown.
 * <p>
 * NOTE(security): this uses native Java deserialization ({@link ObjectInputStream#readObject()}),
 * which is unsafe on untrusted data. Only load files from trusted sources.
 *
 * @param path the file to read the serialized model from
 * @return the deserialized model, or {@code null} if the file could not be read or deserialized
 */
public static FastText readWordVectors(File path) {
    FastText result = null;
    // try-with-resources closes both streams on every path; the original never closed them.
    try (FileInputStream fileIn = new FileInputStream(path);
         ObjectInputStream in = new ObjectInputStream(fileIn)) {
        result = (FastText) in.readObject();
    } catch (ClassNotFoundException | IOException ex) {
        // FileNotFoundException is an IOException, so it is covered here as well.
        // ClassNotFoundException used to be swallowed silently; report it like the I/O failures.
        ex.printStackTrace();
    }
    return result;
}
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils; import org.deeplearning4j.models.embeddings.reader.ModelUtils;
@ -207,7 +208,6 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
} }
INDArray mean = words.isMatrix() ? words.mean(0) : words; INDArray mean = words.isMatrix() ? words.mean(0) : words;
Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size()); Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size());
List<String> realResults = new ArrayList<>(); List<String> realResults = new ArrayList<>();
@ -232,6 +232,22 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
return wordsNearestSum(vec, n); return wordsNearestSum(vec, n);
} }
/**
 * Prepares a query array for the wordsNearest methods when the lookup table is an
 * {@link InMemoryLookupTable}: the array is cast to the data type of the table's syn0 array
 * if it differs, its rank is validated, and a rank-1 array is reshaped to a single-row matrix.
 * For any other lookup table type the input is returned unchanged.
 * <p>
 * Previously the method returned immediately after the cast, so an input whose data type
 * differed from syn0 skipped both the rank validation and the rank-1 reshape.
 *
 * @param words the query vector(s)
 * @return the input, cast and/or reshaped as needed
 * @throws IllegalStateException if the lookup table is in-memory and the input has rank 0 or rank greater than 2
 */
protected INDArray adjustRank(INDArray words) {
    if (!(lookupTable instanceof InMemoryLookupTable)) {
        return words;
    }

    InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
    INDArray syn0 = l.getSyn0();

    // Cast first, then fall through to the rank handling instead of returning early.
    if (!words.dataType().equals(syn0.dataType())) {
        words = words.castTo(syn0.dataType());
    }

    if (words.rank() == 0 || words.rank() > 2) {
        throw new IllegalStateException("Invalid rank for wordsNearest method");
    }
    if (words.rank() == 1) {
        return words.reshape(1, -1);
    }
    return words;
}
/** /**
* Words nearest based on positive and negative words * Words nearest based on positive and negative words
* * @param top the top n words * * @param top the top n words
@ -239,6 +255,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
*/ */
@Override @Override
public Collection<String> wordsNearest(INDArray words, int top) { public Collection<String> wordsNearest(INDArray words, int top) {
words = adjustRank(words);
if (lookupTable instanceof InMemoryLookupTable) { if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.embeddings.reader.impl; package org.deeplearning4j.models.embeddings.reader.impl;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
@ -64,6 +65,8 @@ public class FlatModelUtils<T extends SequenceElement> extends BasicModelUtils<T
public Collection<String> wordsNearest(INDArray words, int top) { public Collection<String> wordsNearest(INDArray words, int top) {
Counter<String> distances = new Counter<>(); Counter<String> distances = new Counter<>();
words = adjustRank(words);
for (String s : vocabCache.words()) { for (String s : vocabCache.words()) {
INDArray otherVec = lookupTable.vector(s); INDArray otherVec = lookupTable.vector(s);
double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup())); double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup()));

View File

@ -103,6 +103,7 @@ public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T
@Override @Override
public Collection<String> wordsNearest(INDArray words, int top) { public Collection<String> wordsNearest(INDArray words, int top) {
checkTree(); checkTree();
words = adjustRank(words);
List<DataPoint> add = new ArrayList<>(); List<DataPoint> add = new ArrayList<>();
List<Double> distances = new ArrayList<>(); List<Double> distances = new ArrayList<>();

View File

@ -172,4 +172,10 @@ public interface WordVectors extends Serializable, EmbeddingInitializer {
*/ */
void setModelUtils(ModelUtils utils); void setModelUtils(ModelUtils utils);
/**
* Indicates whether this implementation can vectorize words that are absent from the vocabulary.
*
* @return true if out-of-vocabulary words can be vectorized, false otherwise
*/
boolean outOfVocabularySupported();
} }

View File

@ -20,6 +20,7 @@ import com.google.common.util.concurrent.AtomicDouble;
import lombok.Getter; import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import lombok.Setter; import lombok.Setter;
import lombok.val;
import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.ArrayUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@ -357,4 +358,9 @@ public class WordVectorsImpl<T extends SequenceElement> implements WordVectors {
public boolean jsonSerializable() { public boolean jsonSerializable() {
return false; return false;
} }
/**
 * This implementation does not support out-of-vocabulary words.
 *
 * @return always false
 */
@Override
public boolean outOfVocabularySupported() {
return false;
}
} }

View File

@ -6,10 +6,13 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.ModelUtils; import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -17,41 +20,78 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import java.io.BufferedWriter; import java.io.*;
import java.io.File; import java.util.*;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@Slf4j @Slf4j
@AllArgsConstructor @AllArgsConstructor
@lombok.Builder @lombok.Builder
public class FastText implements WordVectors { public class FastText implements WordVectors, Serializable {
private boolean supervised; // Mandatory
private boolean quantize; @Getter private String inputFile;
private boolean predict; @Getter private String outputFile;
private boolean predict_prob;
private boolean skipgram; // Optional for dictionary
@Builder.Default private int bucket = 100; @Builder.Default private int bucket = -1;
@Builder.Default private int minCount = 1; @Builder.Default private int minCount = -1;
@Builder.Default private int minCountLabel = -1;
@Builder.Default private int wordNgrams = -1;
@Builder.Default private int minNgramLength = -1;
@Builder.Default private int maxNgramLength = -1;
@Builder.Default private int samplingThreshold = -1;
private String labelPrefix;
private boolean cbow; // Optional for training
private boolean nn; @Getter private boolean supervised;
private boolean analogies; @Getter private boolean quantize;
private String inputFile; @Getter private boolean predict;
private String outputFile; @Getter private boolean predict_prob;
private SentenceIterator iterator; @Getter private boolean skipgram;
private String modelName; @Getter private boolean cbow;
private String lossName; @Getter private boolean nn;
//TODO: @Getter private boolean analogies;
private double[] pretrainedVectors; @Getter private String pretrainedVectorsFile;
@Getter
@Builder.Default
private double learningRate = -1.0;
@Getter private double learningRateUpdate = -1.0;
@Getter
@Builder.Default
private int dim = -1;
@Getter
@Builder.Default
private int contextWindowSize = -1;
@Getter
@Builder.Default
private int epochs = -1;
@Getter private String modelName;
@Getter private String lossName;
@Getter
@Builder.Default
private int negativeSamples = -1;
@Getter
@Builder.Default
private int numThreads = -1;
@Getter private boolean saveOutput = false;
private JFastText fastTextImpl; // Optional for quantization
private boolean modelLoaded; @Getter
@Builder.Default
private int cutOff = -1;
@Getter private boolean retrain;
@Getter private boolean qnorm;
@Getter private boolean qout;
@Getter
@Builder.Default
private int dsub = -1;
@Getter private SentenceIterator iterator;
@Builder.Default private transient JFastText fastTextImpl = new JFastText();
private transient Word2Vec word2Vec;
@Getter private boolean modelLoaded;
@Getter private boolean modelVectorsLoaded;
private VocabCache vocabCache; private VocabCache vocabCache;
public FastText(File modelPath) { public FastText(File modelPath) {
@ -63,8 +103,97 @@ public class FastText implements WordVectors {
fastTextImpl = new JFastText(); fastTextImpl = new JFastText();
} }
public void init() { private static class ArgsFactory {
fastTextImpl = new JFastText();
private List<String> args = new ArrayList<>();
/**
 * Appends a mandatory option to the argument list: the label first, then its value.
 * Nothing is validated here; both entries are stored exactly as given.
 */
private void add(String label, String value) {
    Collections.addAll(args, label, value);
}
/**
 * Appends an integer option (label followed by its decimal text) only when the
 * value is non-negative; negative values are silently skipped.
 */
private void addOptional(String label, int value) {
    if (value < 0) {
        return;
    }
    args.add(label);
    args.add(String.valueOf(value));
}
/**
 * Appends a floating-point option (label followed by its text form) only when the
 * value is greater than or equal to zero.
 */
private void addOptional(String label, double value) {
    // Phrased as a negated ">=" on purpose: NaN must be skipped just like a negative value.
    if (!(value >= 0.0)) {
        return;
    }
    args.add(label);
    args.add(String.valueOf(value));
}
/**
 * Appends a string option (label followed by value) unless the value fails
 * the StringUtils.isNotEmpty check, in which case nothing is added.
 */
private void addOptional(String label, String value) {
    if (!StringUtils.isNotEmpty(value)) {
        return;
    }
    args.add(label);
    args.add(value);
}
/**
 * Appends a bare switch: the label alone is added when the flag is true,
 * and nothing is added when it is false. No value token follows the label.
 */
private void addOptional(String label, boolean value) {
    if (!value) {
        return;
    }
    args.add(label);
}
/**
 * Returns the collected arguments, in insertion order, as a new String array.
 */
public String[] args() {
    return args.toArray(new String[0]);
}
}
/**
 * Translates the configured fields of this instance into the argument array that
 * {@link #fit()} passes to the native runner.
 *
 * Boolean fields become bare words when true, "-input"/"-output" are always emitted,
 * and every other option is emitted only when ArgsFactory's addOptional accepts its
 * value (non-negative number or non-empty string). The emission order below is
 * significant and must not be rearranged.
 *
 * @return the arguments in emission order
 */
private String[] makeArgs() {
    ArgsFactory cli = new ArgsFactory();

    // Mode words (no leading dash).
    cli.addOptional("cbow", cbow);
    cli.addOptional("skipgram", skipgram);
    cli.addOptional("supervised", supervised);
    cli.addOptional("quantize", quantize);
    cli.addOptional("predict", predict);
    cli.addOptional("predict_prob", predict_prob);

    // Mandatory file locations: added unconditionally, even when unset.
    cli.add("-input", inputFile);
    cli.add("-output", outputFile );

    // Dictionary options.
    cli.addOptional("-pretrainedVectors", pretrainedVectorsFile);
    cli.addOptional("-bucket", bucket);
    cli.addOptional("-minCount", minCount);
    cli.addOptional("-minCountLabel", minCountLabel);
    cli.addOptional("-wordNgrams", wordNgrams);
    cli.addOptional("-minn", minNgramLength);
    cli.addOptional("-maxn", maxNgramLength);
    cli.addOptional("-t", samplingThreshold);
    cli.addOptional("-label", labelPrefix);

    // NOTE(review): "analogies" is a dash-less word emitted in the middle of the dash
    // options, unlike the other mode words above - confirm the native parser accepts it here.
    cli.addOptional("analogies",analogies);

    // Training options.
    cli.addOptional("-lr", learningRate);
    cli.addOptional("-lrUpdateRate", learningRateUpdate);
    cli.addOptional("-dim", dim);
    cli.addOptional("-ws", contextWindowSize);
    cli.addOptional("-epoch", epochs);
    cli.addOptional("-loss", lossName);
    cli.addOptional("-neg", negativeSamples);
    cli.addOptional("-thread", numThreads);
    cli.addOptional("-saveOutput", saveOutput);

    // Quantization options.
    cli.addOptional("-cutoff", cutOff);
    cli.addOptional("-retrain", retrain);
    cli.addOptional("-qnorm", qnorm);
    cli.addOptional("-qout", qout);
    cli.addOptional("-dsub", dsub);

    return cli.args();
}
/**
 * Builds the argument array from the current configuration and hands it to the
 * underlying JFastText instance via runCmd.
 */
public void fit() {
    fastTextImpl.runCmd(makeArgs());
}
public void loadIterator() {
if (iterator != null) { if (iterator != null) {
try { try {
File tempFile = File.createTempFile("FTX", ".txt"); File tempFile = File.createTempFile("FTX", ".txt");
@ -81,24 +210,11 @@ public class FastText implements WordVectors {
} }
} }
public void fit() { public void loadPretrainedVectors(File vectorsFile) {
word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
String[] cmd; modelVectorsLoaded = true;
if (skipgram) { log.info("Loaded vectorized representation from file %s. Functionality will be restricted.",
cmd = new String[]{"skipgram", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount), vectorsFile.getAbsolutePath());
"-input", inputFile, "-output", outputFile};
}
else if (cbow) {
cmd = new String[]{"cbow", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
"-input", inputFile, "-output", outputFile};
}
else if (supervised)
cmd = new String[]{"supervised", "-input", inputFile,
"-output", outputFile};
else
cmd = new String[]{"-input", inputFile,
"-output", outputFile};
fastTextImpl.runCmd(cmd);
} }
public void loadBinaryModel(String modelPath) { public void loadBinaryModel(String modelPath) {
@ -111,10 +227,18 @@ public class FastText implements WordVectors {
modelLoaded = false; modelLoaded = false;
} }
/**
 * Passes the absolute path of the given file to the underlying JFastText test call.
 *
 * @param testFile file whose absolute path is forwarded
 */
public void test(File testFile) {
    String testPath = testFile.getAbsolutePath();
    fastTextImpl.test(testPath);
}
/**
 * Guard used by the predict methods: passes when either a binary model or a set of
 * pretrained vectors has been loaded.
 *
 * @throws IllegalStateException when neither flag is set
 */
private void assertModelLoaded() {
    boolean anythingLoaded = modelLoaded || modelVectorsLoaded;
    if (!anythingLoaded) {
        throw new IllegalStateException("Model must be loaded before predict!");
    }
}
public String predict(String text) { public String predict(String text) {
if (!modelLoaded) assertModelLoaded();
throw new IllegalStateException("Model must be loaded before predict!");
String label = fastTextImpl.predict(text); String label = fastTextImpl.predict(text);
return label; return label;
@ -122,8 +246,7 @@ public class FastText implements WordVectors {
public Pair<String, Float> predictProbability(String text) { public Pair<String, Float> predictProbability(String text) {
if (!modelLoaded) assertModelLoaded();
throw new IllegalStateException("Model must be loaded before predict!");
JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text); JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);
@ -135,27 +258,39 @@ public class FastText implements WordVectors {
@Override @Override
public VocabCache vocab() { public VocabCache vocab() {
if (!modelLoaded) if (modelVectorsLoaded) {
throw new IllegalStateException("Load model before calling vocab()"); vocabCache = word2Vec.vocab();
if (vocabCache == null) {
vocabCache = new AbstractCache();
} }
List<String> words = fastTextImpl.getWords(); else {
for (int i = 0; i < words.size(); ++i) { if (!modelLoaded)
vocabCache.addWordToIndex(i, words.get(i)); throw new IllegalStateException("Load model before calling vocab()");
VocabWord word = new VocabWord();
word.setWord(words.get(i)); if (vocabCache == null) {
vocabCache.addToken(word); vocabCache = new AbstractCache();
}
List<String> words = fastTextImpl.getWords();
for (int i = 0; i < words.size(); ++i) {
vocabCache.addWordToIndex(i, words.get(i));
VocabWord word = new VocabWord();
word.setWord(words.get(i));
vocabCache.addToken(word);
}
} }
return vocabCache; return vocabCache;
} }
@Override @Override
public long vocabSize() { public long vocabSize() {
if (!modelLoaded) long result = 0;
throw new IllegalStateException("Load model before calling vocab()"); if (modelVectorsLoaded) {
return fastTextImpl.getNWords(); result = word2Vec.vocabSize();
}
else {
if (!modelLoaded)
throw new IllegalStateException("Load model before calling vocab()");
result = fastTextImpl.getNWords();
}
return result;
} }
@Override @Override
@ -170,99 +305,160 @@ public class FastText implements WordVectors {
@Override @Override
public double[] getWordVector(String word) { public double[] getWordVector(String word) {
List<Float> vectors = fastTextImpl.getVector(word); if (modelVectorsLoaded) {
double[] retVal = new double[vectors.size()]; return word2Vec.getWordVector(word);
for (int i = 0; i < vectors.size(); ++i) { }
retVal[i] = vectors.get(i); else {
List<Float> vectors = fastTextImpl.getVector(word);
double[] retVal = new double[vectors.size()];
for (int i = 0; i < vectors.size(); ++i) {
retVal[i] = vectors.get(i);
}
return retVal;
} }
return retVal;
} }
@Override @Override
public INDArray getWordVectorMatrixNormalized(String word) { public INDArray getWordVectorMatrixNormalized(String word) {
INDArray r = getWordVectorMatrix(word); if (modelVectorsLoaded) {
return r.divi(Nd4j.getBlasWrapper().nrm2(r)); return word2Vec.getWordVectorMatrixNormalized(word);
}
else {
INDArray r = getWordVectorMatrix(word);
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
}
} }
@Override @Override
public INDArray getWordVectorMatrix(String word) { public INDArray getWordVectorMatrix(String word) {
double[] values = getWordVector(word); if (modelVectorsLoaded) {
return Nd4j.createFromArray(values); return word2Vec.getWordVectorMatrix(word);
}
else {
double[] values = getWordVector(word);
return Nd4j.createFromArray(values);
}
} }
@Override @Override
public INDArray getWordVectors(Collection<String> labels) { public INDArray getWordVectors(Collection<String> labels) {
if (modelVectorsLoaded) {
return word2Vec.getWordVectors(labels);
}
return null; return null;
} }
@Override @Override
public INDArray getWordVectorsMean(Collection<String> labels) { public INDArray getWordVectorsMean(Collection<String> labels) {
if (modelVectorsLoaded) {
return word2Vec.getWordVectorsMean(labels);
}
return null; return null;
} }
private List<String> words = new ArrayList<>();
@Override @Override
public boolean hasWord(String word) { public boolean hasWord(String word) {
return fastTextImpl.getWords().contains(word); if (modelVectorsLoaded) {
return word2Vec.outOfVocabularySupported();
}
if (words.isEmpty())
words = fastTextImpl.getWords();
return words.contains(word);
} }
protected transient ModelUtils modelUtils = new BasicModelUtils<>(); protected transient ModelUtils modelUtils;
@Override @Override
public Collection<String> wordsNearest(INDArray words, int top) { public Collection<String> wordsNearest(INDArray words, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(words, top);
}
return modelUtils.wordsNearest(words, top); return modelUtils.wordsNearest(words, top);
} }
@Override @Override
public Collection<String> wordsNearestSum(INDArray words, int top) { public Collection<String> wordsNearestSum(INDArray words, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(words, top);
}
return modelUtils.wordsNearestSum(words, top); return modelUtils.wordsNearestSum(words, top);
} }
@Override @Override
public Collection<String> wordsNearestSum(String word, int n) { public Collection<String> wordsNearestSum(String word, int n) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(word, n);
}
return modelUtils.wordsNearestSum(word, n); return modelUtils.wordsNearestSum(word, n);
} }
@Override @Override
public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) { public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(positive, negative, top);
}
return modelUtils.wordsNearestSum(positive, negative, top); return modelUtils.wordsNearestSum(positive, negative, top);
} }
@Override @Override
public Map<String, Double> accuracy(List<String> questions) { public Map<String, Double> accuracy(List<String> questions) {
if (modelVectorsLoaded) {
return word2Vec.accuracy(questions);
}
return modelUtils.accuracy(questions); return modelUtils.accuracy(questions);
} }
@Override @Override
public int indexOf(String word) { public int indexOf(String word) {
if (modelVectorsLoaded) {
return word2Vec.indexOf(word);
}
return vocab().indexOf(word); return vocab().indexOf(word);
} }
@Override @Override
public List<String> similarWordsInVocabTo(String word, double accuracy) { public List<String> similarWordsInVocabTo(String word, double accuracy) {
if (modelVectorsLoaded) {
return word2Vec.similarWordsInVocabTo(word, accuracy);
}
return modelUtils.similarWordsInVocabTo(word, accuracy); return modelUtils.similarWordsInVocabTo(word, accuracy);
} }
@Override @Override
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) { public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(positive, negative, top);
}
return modelUtils.wordsNearest(positive, negative, top); return modelUtils.wordsNearest(positive, negative, top);
} }
@Override @Override
public Collection<String> wordsNearest(String word, int n) { public Collection<String> wordsNearest(String word, int n) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(word,n);
}
return modelUtils.wordsNearestSum(word, n); return modelUtils.wordsNearestSum(word, n);
} }
@Override @Override
public double similarity(String word, String word2) { public double similarity(String word, String word2) {
if (modelVectorsLoaded) {
return word2Vec.similarity(word, word2);
}
return modelUtils.similarity(word, word2); return modelUtils.similarity(word, word2);
} }
@Override @Override
public WeightLookupTable lookupTable() { public WeightLookupTable lookupTable() {
if (modelVectorsLoaded) {
return word2Vec.lookupTable();
}
return null; return null;
} }
@ -320,4 +516,9 @@ public class FastText implements WordVectors {
return fastTextImpl.getLabelPrefix(); return fastTextImpl.getLabelPrefix();
} }
/**
 * Always returns {@code true} for this implementation.
 * NOTE(review): presumably signals that vectors can be produced for words that are
 * not in the vocabulary - confirm against the interface contract.
 */
@Override
public boolean outOfVocabularySupported() {
return true;
}
} }

View File

@ -376,6 +376,11 @@ public class StaticWord2Vec implements WordVectors {
return false; return false;
} }
/**
 * Always returns {@code false} for this implementation.
 * NOTE(review): presumably signals that only words present in the vocabulary can be
 * resolved to vectors - confirm against the interface contract.
 */
@Override
public boolean outOfVocabularySupported() {
return false;
}
public static class Builder { public static class Builder {
private AbstractStorage<Integer> storage; private AbstractStorage<Integer> storage;

View File

@ -1,6 +1,10 @@
package org.deeplearning4j.models.fasttext; package org.deeplearning4j.models.fasttext;
import com.github.jfasttext.JFastText;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -13,6 +17,7 @@ import org.nd4j.resources.Resources;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertArrayEquals;
@ -23,7 +28,9 @@ import static org.junit.Assert.assertEquals;
public class FastTextTest extends BaseDL4JTest { public class FastTextTest extends BaseDL4JTest {
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
private File modelFile = Resources.asFile("models/fasttext/supervised.model.bin"); private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
@Rule @Rule
@ -39,7 +46,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()). inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build(); outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit(); fastText.fit();
} }
@ -53,7 +59,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()). inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build(); outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit(); fastText.fit();
} }
@ -68,7 +73,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()). inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build(); outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit(); fastText.fit();
} }
@ -82,34 +86,42 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()). inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build(); outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n"); log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit(); fastText.fit();
} }
@Ignore @Test
public void tesLoadCBOWModel() throws IOException {
FastText fastText = new FastText(cbowModelFile);
fastText.test(cbowModelFile);
assertEquals(19, fastText.vocab().numWords());
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 
0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
}
@Test @Test
public void testPredict() throws IOException { public void testPredict() throws IOException {
for (int i = 0; i < 100; ++i) {
String text = "I like soccer"; String text = "I like soccer";
FastText fastText = new FastText(modelFile); FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords()); assertEquals(48, fastText.vocab().numWords());
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1)); assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 
0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582}; double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 
0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-5); assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
String label = fastText.predict(text); String label = fastText.predict(text);
assertEquals("__label__soccer", label); assertEquals("__label__soccer", label);
}
} }
@Ignore
@Test @Test
public void testPredictProbability() throws IOException { public void testPredictProbability() throws IOException {
String text = "I like soccer"; String text = "I like soccer";
FastText fastText = new FastText(modelFile); FastText fastText = new FastText(supModelFile);
Pair<String,Float> result = fastText.predictProbability(text); Pair<String,Float> result = fastText.predictProbability(text);
assertEquals("__label__soccer", result.getFirst()); assertEquals("__label__soccer", result.getFirst());
@ -129,7 +141,7 @@ public class FastTextTest extends BaseDL4JTest {
@Test @Test
public void testVocabulary() throws IOException { public void testVocabulary() throws IOException {
FastText fastText = new FastText(modelFile); FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords()); assertEquals(48, fastText.vocab().numWords());
assertEquals(48, fastText.vocabSize()); assertEquals(48, fastText.vocabSize());
@ -149,7 +161,7 @@ public class FastTextTest extends BaseDL4JTest {
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
FastText fastText = FastText fastText =
FastText.builder().supervised(true).iterator(iter).build(); FastText.builder().supervised(true).iterator(iter).build();
fastText.init(); fastText.loadIterator();
} catch (IOException e) { } catch (IOException e) {
log.error(e.toString()); log.error(e.toString());
@ -162,4 +174,60 @@ public class FastTextTest extends BaseDL4JTest {
String label = fastText.predict("something"); String label = fastText.predict("something");
} }
/**
 * Smoke test: runs a supervised fit with a pretrained-vectors file configured.
 * It passes as long as fit() completes; no assertions are made on the result.
 */
@Test
public void testPretrainedVectors() throws IOException {
    File output = testDir.newFile();
    String trainingData = inputFile.getAbsolutePath();
    String pretrained = supervisedVectors.getAbsolutePath();
    String target = output.getAbsolutePath();

    FastText fastText = FastText.builder()
                    .supervised(true)
                    .inputFile(trainingData)
                    .pretrainedVectorsFile(pretrained)
                    .outputFile(target)
                    .build();

    log.info("\nTraining supervised model ...\n");
    fastText.fit();
}
/**
 * Trains a supervised model, reads the ".vec" file written next to the output path
 * back as a Word2Vec model, checks the vocabulary size and prints a few
 * nearest-word / similarity queries.
 */
@Test
public void testWordsStatistics() throws IOException {
    File output = testDir.newFile();
    String outputPath = output.getAbsolutePath();

    FastText fastText = FastText.builder()
                    .supervised(true)
                    .inputFile(inputFile.getAbsolutePath())
                    .outputFile(outputPath)
                    .build();

    log.info("\nTraining supervised model ...\n");
    fastText.fit();

    // The vectors are read from "<output>.vec".
    File vectorsFile = new File(outputPath + ".vec");
    Word2Vec word2Vec = WordVectorSerializer.readAsCsv(vectorsFile);

    assertEquals(48, word2Vec.getVocab().numWords());

    System.out.println(word2Vec.wordsNearest("association", 3));
    System.out.println(word2Vec.similarity("Football", "teams"));
    System.out.println(word2Vec.similarity("professional", "minutes"));
    System.out.println(word2Vec.similarity("java","cpp"));
}
/**
 * Loads pretrained vectors into a FastText instance (no training involved) and
 * verifies vocabulary size, nearest words and pairwise similarities served through
 * the vector-backed code path.
 */
@Test
public void testWordsNativeStatistics() throws IOException {
    FastText fastText = new FastText();
    fastText.loadPretrainedVectors(supervisedVectors);

    log.info("\nLoaded pretrained vectors ...\n");
    assertEquals(48, fastText.vocab().numWords());

    // Use the array returned by toArray: a preallocated array is left untouched
    // when the collection holds more elements than it can fit.
    String[] result = fastText.wordsNearest("association", 3).toArray(new String[0]);
    assertArrayEquals(new String[]{"most","eleven","hours"}, result);

    assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
    assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
    // Both words are expected to be unknown, which yields NaN.
    assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
}
} }

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.fasttext.FastText;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.VocabWord;
@ -42,6 +43,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@ -289,4 +291,33 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
} }
} }
} }
/**
 * Round-trips a FastText configuration through WordVectorSerializer and checks that
 * the flags and file names read back match the instance that was written.
 */
@Test
public void FastText_Correct_WhenDeserialized() throws IOException {

    FastText fastText =
            FastText.builder().cbow(true).build();

    // Write to a throw-away temp file instead of polluting the working directory.
    File serialized = File.createTempFile("fasttext-serde", ".data");
    serialized.deleteOnExit();

    WordVectorSerializer.writeWordVectors(fastText, serialized);

    FastText deser = null;
    try {
        deser = WordVectorSerializer.readWordVectors(serialized);
    } catch (Exception e) {
        // Keep the original cause attached so the failure is diagnosable.
        throw new AssertionError("Deserialization of FastText failed", e);
    }

    assertNotNull(deser);
    assertEquals(fastText.isCbow(), deser.isCbow());
    assertEquals(fastText.isModelLoaded(), deser.isModelLoaded());
    assertEquals(fastText.isAnalogies(), deser.isAnalogies());
    assertEquals(fastText.isNn(), deser.isNn());
    assertEquals(fastText.isPredict(), deser.isPredict());
    assertEquals(fastText.isPredict_prob(), deser.isPredict_prob());
    assertEquals(fastText.isQuantize(), deser.isQuantize());
    assertEquals(fastText.getInputFile(), deser.getInputFile());
    assertEquals(fastText.getOutputFile(), deser.getOutputFile());
}
} }

View File

@ -453,6 +453,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
DataType netDtype = getConfiguration().getDataType(); DataType netDtype = getConfiguration().getDataType();
if(parameters != null && parameters.dataType() != netDtype){ if(parameters != null && parameters.dataType() != netDtype){
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
if(cloneParametersArray){ if(cloneParametersArray){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
parameters = parameters.castTo(netDtype); parameters = parameters.castTo(netDtype);

View File

@ -178,8 +178,7 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
", mask shape: " + Arrays.toString(maskArray.shape())); ", mask shape: " + Arrays.toString(maskArray.shape()));
} }
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength] //Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray, ret, 0, 2); Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
// ret.muliColumnVector(maskArray);
} }
return ret; return ret;
} }

View File

@ -616,6 +616,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
DataType netDtype = getLayerWiseConfigurations().getDataType(); DataType netDtype = getLayerWiseConfigurations().getDataType();
if(parameters != null && parameters.dataType() != netDtype){ if(parameters != null && parameters.dataType() != netDtype){
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
if(cloneParametersArray){ if(cloneParametersArray){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
parameters = parameters.castTo(netDtype); parameters = parameters.castTo(netDtype);
@ -627,6 +628,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
if (layerMap == null) if (layerMap == null)
layerMap = new LinkedHashMap<>(); layerMap = new LinkedHashMap<>();

View File

@ -32,7 +32,7 @@ import java.util.Arrays;
@Ignore @Ignore
public class TestSameDiffUI { public class TestSameDiffUI {
// @Ignore @Ignore
@Test @Test
public void testSameDiff() throws Exception { public void testSameDiff() throws Exception {

View File

@ -1598,9 +1598,6 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
int NDArray::rankOf() const { int NDArray::rankOf() const {
if (isEmpty())
return 0;
return shape::rank(_shapeInfo); return shape::rank(_shapeInfo);
} }

View File

@ -36,7 +36,7 @@ std::string NDArray::e(const Nd4jLong i) const;
template <typename T> template <typename T>
NDArray* NDArray::asT() const{ NDArray* NDArray::asT() const{
auto result = new NDArray(ordering(), isScalar() ? std::vector<Nd4jLong>({0}) : getShapeAsVector(), DataTypeUtils::fromT<T>()); auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
auto l = this->lengthOf(); auto l = this->lengthOf();
prepareSpecialUse({result}, {this}); prepareSpecialUse({result}, {this});
@ -67,17 +67,18 @@ NDArray::NDArray(const NDArray& other) {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) { NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
if (shape.empty())
throw std::runtime_error("NDArray constructor: input shape is empty !");
if ((int) shape.size() > MAX_RANK) if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("Rank of NDArray can't exceed 32"); throw std::invalid_argument("Rank of NDArray can't exceed 32");
_context = context; _context = context;
_isAttached = getContext()->getWorkspace() != nullptr; _isAttached = _context->getWorkspace() != nullptr;
_offset = 0; _offset = 0;
setShapeInfo(ShapeDescriptor(dtype, order, shape)); if (shape.empty())
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
else
setShapeInfo(ShapeDescriptor(dtype, order, shape));
_buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); _buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace());
_buffer->setToZeroBuffers(); _buffer->setToZeroBuffers();
} }
@ -85,16 +86,20 @@ NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::Dat
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) { NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) {
if (shape.empty())
throw std::runtime_error("NDArray constructor: input shape is empty !");
if ((int) shape.size() > MAX_RANK) if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("Rank of NDArray can't exceed 32"); throw std::invalid_argument("Rank of NDArray can't exceed 32");
_context = context; _context = context;
_offset = 0; _offset = 0;
setShapeInfo(ShapeDescriptor(dtype, order, shape)); if (shape.size() == 0) {
if (data.size() == 0)
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
else
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
} else {
setShapeInfo(ShapeDescriptor(dtype, order, shape));
}
if (lengthOf() != data.size()) { if (lengthOf() != data.size()) {
nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf());
@ -2441,6 +2446,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB())) if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB()))
throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
if (isEmpty() || other->isEmpty())
return;
NDArray::prepareSpecialUse({target}, {this, other}); NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) { if (isScalar()) {
@ -2513,6 +2521,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
if(target == nullptr || other == nullptr) if(target == nullptr || other == nullptr)
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !"); throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !");
if (isEmpty() || other->isEmpty())
return;
NDArray::prepareSpecialUse({target}, {this, other}); NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) { if (isScalar()) {
@ -2583,6 +2594,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
if (isEmpty() || other.isEmpty()) {
if (isEmpty())
return NDArray(*this);
else
return NDArray(other);
}
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)()
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
@ -2812,6 +2830,19 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
return true; return true;
const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end();
if(isEmpty() && !isOutShapeEmpty)
throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !");
if(!isEmpty() && isOutShapeEmpty)
throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !");
if(isEmpty() && isOutShapeEmpty) {
Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace());
setShapeInfo(shapeInfoNew);
RELEASE(shapeInfoNew, getContext()->getWorkspace());
return true;
}
std::vector<Nd4jLong> shape(cshape); std::vector<Nd4jLong> shape(cshape);
int rank = shape.size(); int rank = shape.size();
@ -2823,7 +2854,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
for (int i = 0; i < (int) shape.size(); i++) { for (int i = 0; i < (int) shape.size(); i++) {
if (shape[i] < 0) { if (shape[i] < 0) {
if (numberNegativesOnes >= 1) if (numberNegativesOnes >= 1)
throw std::runtime_error("Only one dimension can be negative at once"); throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
numberNegativesOnes++; numberNegativesOnes++;
@ -3664,7 +3695,7 @@ void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, co
if(rankOf() == copy.size() || copy.empty()) { if(rankOf() == copy.size() || copy.empty()) {
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo());
} }
else { else { //if (!isEmpty()) {
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
@ -4198,6 +4229,9 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& d
// operator returns sub-array with buffer pointing at this->_buffer + certain offset // operator returns sub-array with buffer pointing at this->_buffer + certain offset
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const { NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
if(isEmpty())
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");
const int rank = rankOf(); const int rank = rankOf();
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace()); Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
@ -4260,6 +4294,9 @@ NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& di
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const {
if(isEmpty())
throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !");
const int rank = rankOf(); const int rank = rankOf();
const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size(); const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size();
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude);

View File

@ -1334,18 +1334,7 @@ public:
* @param npyArray * @param npyArray
* @return * @return
*/ */
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray);
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
unsigned int shapeSize = arr.shape.size();
auto shape = new unsigned int[shapeSize];
for(unsigned int i = 0; i < shapeSize; i++) {
shape[i] = arr.shape[i];
}
auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder);
delete[] shape;
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}
/** /**

View File

@ -64,8 +64,8 @@ void NDArray::tickWriteDevice() const { }
void NDArray::tickReadHost() const { } void NDArray::tickReadHost() const { }
void NDArray::tickReadDevice() const { } void NDArray::tickReadDevice() const { }
void NDArray::tickBothActual() const { } void NDArray::tickBothActual() const { }
bool NDArray::isActualOnHostSide() const { } bool NDArray::isActualOnHostSide() const { return true; }
bool NDArray::isActualOnDeviceSide() const { } bool NDArray::isActualOnDeviceSide() const { return true; }
void NDArray::makeBothBuffersActual() const { } void NDArray::makeBothBuffersActual() const { }
@ -419,328 +419,8 @@ void NDArray::repeat(int dimension, NDArray& target) const {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
#ifndef __JAVACPP_HACK__ #ifndef __JAVACPP_HACK__
template<typename T> #include "NDArrayLambda.hpp"
// Applies func element-wise to (this, second, third) and stores the result in target,
// or back into this array when target is nullptr.
// Throws std::runtime_error when second or third is nullptr, when T does not match this
// array's data type, when any of the four arrays has a different data type, or when the
// lengths/shapes of this, second and third differ.
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
    if (target == nullptr)
        target = this;

    if (second == nullptr) {
        nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
        throw std::runtime_error("second is null");
    }

    if (third == nullptr) {
        nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
        throw std::runtime_error("third is null");
    }

    if(dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
    if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
        throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");

    if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
        // The log line previously named applyPairwiseLambda, which pointed readers at the wrong method.
        nd4j_printf("applyTriplewiseLambda requires all three operands to have the same shape\n","");
        throw std::runtime_error("Shapes mismach");
    }

    auto f = this->bufferAsT<T>();
    auto s = second->bufferAsT<T>();
    auto t = third->bufferAsT<T>();
    auto z = target->bufferAsT<T>();

    // Fast path: all arrays share ordering and have element-wise stride 1, so plain indexing works.
    if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong e = 0; e < _length; e++)
            z[e] = func(f[e], s[e], t[e]);
    } else {
        // The strided loops use an Nd4jLong index like the fast path above;
        // an int counter would truncate once the length exceeds INT_MAX.
        if (f == z) {
            // In-place: target shares this array's buffer, so results are written through this array's offsets.
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto tOffset = this->getOffset(e);
                auto uOffset = second->getOffset(e);
                auto vOffset = third->getOffset(e);

                f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
            }
        } else {
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto tOffset = this->getOffset(e);
                auto uOffset = second->getOffset(e);
                auto vOffset = third->getOffset(e);
                auto zOffset = target->getOffset(e);

                z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
            }
        }
    }
}
// Explicit instantiations of NDArray::applyTriplewiseLambda, one per element type listed below.
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
// Applies func element-wise to (this, other) and stores the result in target,
// or back into this array when target is nullptr.
// Throws std::runtime_error when other is nullptr, when T does not match this array's
// data type, when this/other/target differ in data type, or when the lengths differ.
// Note: only lengthOf() is compared here, although the log message speaks of "shape".
template<typename T>
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
    if (target == nullptr)
        target = this;

    if (other == nullptr) {
        nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
        throw std::runtime_error("Other is null");
    }

    if(dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
    if(dataType() != other->dataType() || dataType() != target->dataType())
        throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");

    if (this->lengthOf() != other->lengthOf()) {
        nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
        throw std::runtime_error("Shapes mismach");
    }

    auto f = this->bufferAsT<T>();
    auto s = other->bufferAsT<T>();
    auto z = target->bufferAsT<T>();

    // All loops index with Nd4jLong: an int counter would truncate once the
    // length exceeds INT_MAX.
    if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
        // Fast path: same ordering and element-wise stride 1, so plain indexing works.
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong e = 0; e < _length; e++)
            z[e] = func(f[e], s[e]);
    } else {
        if (f == z) {
            // In-place: target shares this array's buffer, so results are written through this array's offsets.
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);
                auto yOffset = other->getOffset(e);

                f[xOffset] = func(f[xOffset], s[yOffset]);
            }
        } else {
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);
                auto yOffset = other->getOffset(e);
                auto zOffset = target->getOffset(e);

                z[zOffset] = func(f[xOffset], s[yOffset]);
            }
        }
    }
}
// Explicit instantiations of NDArray::applyPairwiseLambda, one per element type listed below.
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
// Applies func to every element of this array and stores the result in target,
// or back into this array when target is nullptr.
// Throws std::runtime_error when T does not match this array's data type or when
// target has a different data type.
template<typename T>
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
    if (target == nullptr)
        target = this;

    if(dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
    if(dataType() != target->dataType())
        throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");

    auto f = this->bufferAsT<T>();
    auto z = target->bufferAsT<T>();

    // All loops index with Nd4jLong: an int counter would truncate once the
    // length exceeds INT_MAX.
    if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
        // Fast path: same ordering and element-wise stride 1, so plain indexing works.
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong e = 0; e < _length; e++)
            z[e] = func(f[e]);
    } else {
        if (f == z) {
            // In-place: target shares this array's buffer, so results are written through this array's offsets.
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);

                f[xOffset] = func(f[xOffset]);
            }
        } else {
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);
                auto zOffset = target->getOffset(e);

                z[zOffset] = func(f[xOffset]);
            }
        }
    }
}
// Explicit instantiations of NDArray::applyLambda, one per element type listed below.
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
// Calls func(index, value) for every element of this array and stores the result
// in target, or back into this array when target is nullptr.
// Throws std::runtime_error when T does not match this array's data type or when
// target has a different data type.
template<typename T>
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
    NDArray* const dst = (target != nullptr) ? target : this;

    if (dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
    if (dataType() != dst->dataType())
        throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");

    auto in  = this->bufferAsT<T>();
    auto out = dst->bufferAsT<T>();

    // Contiguous case: identical ordering and element-wise stride 1 on both sides.
    const bool linear = this->ordering() == dst->ordering() && (this->ews() == 1 && dst->ews() == 1);
    if (linear) {
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong i = 0; i < _length; i++)
            out[i] = func(i, in[i]);
        return;
    }

    if (in == out) {
        // Same buffer: update in place through this array's offsets.
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong i = 0; i < _length; i++) {
            auto srcOffset = this->getOffset(i);
            in[srcOffset] = func(i, in[srcOffset]);
        }
        return;
    }

    // Distinct buffers: read through this array's offsets, write through the target's.
    PRAGMA_OMP_PARALLEL_FOR_SIMD
    for (Nd4jLong i = 0; i < _length; i++) {
        auto srcOffset = this->getOffset(i);
        auto dstOffset = dst->getOffset(i);
        out[dstOffset] = func(i, in[srcOffset]);
    }
}
// Explicit instantiations of NDArray::applyIndexedLambda, one per element type listed below.
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
// Calls func(index, thisValue, otherValue) for every element and stores the result
// in target, or back into this array when target is nullptr.
// Throws std::runtime_error when other is nullptr, when T does not match this array's
// data type, when target has a different data type, or when the lengths differ.
// NOTE(review): the data type of `other` is not validated here, although its buffer is
// read as T; applyPairwiseLambda does check it. Confirm whether that is intentional.
template<typename T>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
    if (target == nullptr)
        target = this;

    if (other == nullptr) {
        nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
        throw std::runtime_error("Other is null");
    }

    if(dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
    if(dataType() != target->dataType())
        throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
    if (this->lengthOf() != other->lengthOf()) {
        nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
        throw std::runtime_error("Shapes mismach");
    }

    auto f = this->bufferAsT<T>();
    auto s = other->bufferAsT<T>();
    auto z = target->bufferAsT<T>();

    if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
        // Fast path: same ordering and element-wise stride 1, so plain indexing works.
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong e = 0; e < _length; e++)
            z[e] = func(e, f[e], s[e]);
    } else {
        // The strided loops use an Nd4jLong index like the fast path above; an int counter
        // would truncate (and pass a wrong index to func) once the length exceeds INT_MAX.
        if (f == z) {
            // In-place: target shares this array's buffer, so results are written through this array's offsets.
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);
                auto yOffset = other->getOffset(e);

                f[xOffset] = func(e, f[xOffset], s[yOffset]);
            }
        } else {
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto xOffset = this->getOffset(e);
                auto yOffset = other->getOffset(e);
                auto zOffset = target->getOffset(e);

                z[zOffset] = func(e, f[xOffset], s[yOffset]);
            }
        }
    }
}
// Explicit instantiations of NDArray::applyIndexedPairwiseLambda, one per element type listed below.
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
#endif #endif
/* /*

View File

@ -0,0 +1,325 @@
/**
 * Applies a three-argument lambda element by element: for every linear index e the value
 * func(this[e], second[e], third[e]) is stored into target[e].
 *
 * @tparam T element type; must correspond to the data type of this array
 * @param second second operand, must not be null
 * @param third third operand, must not be null
 * @param func callable invoked once per element
 * @param target destination array; when null, results are written back into this array
 * @throws std::runtime_error if second or third is null, if T or the data type of any of the
 *         arrays differs from the data type of this array, or if second/third differ from
 *         this array in length or shape
 */
template<typename T>
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
    if (target == nullptr)
        target = this;

    if (second == nullptr) {
        nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
        throw std::runtime_error("second is null");
    }

    if (third == nullptr) {
        nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
        throw std::runtime_error("third is null");
    }

    if(dataType() != DataTypeUtils::fromT<T>())
        throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");

    if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
        throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");

    // NOTE(review): the length/shape of target is not validated here - confirm callers guarantee it
    if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
        nd4j_printf("applyTriplewiseLambda requires all operands to have the same shape\n","");
        throw std::runtime_error("Shapes mismach");
    }

    auto f = this->bufferAsT<T>();
    auto s = second->bufferAsT<T>();
    auto t = third->bufferAsT<T>();
    auto z = target->bufferAsT<T>();

    if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
        // all arrays share the ordering and report ews() == 1: index the buffers directly
        PRAGMA_OMP_PARALLEL_FOR_SIMD
        for (Nd4jLong e = 0; e < _length; e++)
            z[e] = func(f[e], s[e], t[e]);
    } else {
        // the loop index is Nd4jLong (same as in the branch above) so that it cannot
        // overflow when _length exceeds the range of int
        if (f == z) {
            // this and target share one buffer: a single offset serves for read and write
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto tOffset = this->getOffset(e);
                auto uOffset = second->getOffset(e);
                auto vOffset = third->getOffset(e);

                f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
            }
        } else {
            PRAGMA_OMP_PARALLEL_FOR_SIMD
            for (Nd4jLong e = 0; e < _length; e++) {
                auto tOffset = this->getOffset(e);
                auto uOffset = second->getOffset(e);
                auto vOffset = third->getOffset(e);
                auto zOffset = target->getOffset(e);

                z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
            }
        }
    }
}
// Explicit instantiations of NDArray::applyTriplewiseLambda, one per element type T
// (T is both the lambda's return type and the type of its three arguments).
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != other->dataType() || dataType() != target->dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func(f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset], s[yOffset]);
}
}
}
}
// Explicit instantiations of NDArray::applyPairwiseLambda, one per element type T
// (T is both the lambda's return type and the type of its two arguments).
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset]);
}
}
}
}
// Explicit instantiations of NDArray::applyLambda, one per element type T
// (T is both the lambda's return type and the type of its argument).
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func(e, f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(e, f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(e, f[xOffset]);
}
}
}
}
// Explicit instantiations of NDArray::applyIndexedLambda, one per element type T
// (T is both the lambda's return type and the type of its value argument).
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func((Nd4jLong) e, f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
}
}
}
// Explicit instantiations of NDArray::applyIndexedPairwiseLambda, one per element type T
// (T is both the lambda's return type and the type of its two value arguments).
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);

View File

@ -2710,6 +2710,32 @@ int NativeOps::dataTypeFromNpyHeader(void *header) {
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header)); return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
} }
/**
 * Builds a shapeInfo buffer for the numpy array found at the given pointer.
 *
 * The shape, the ordering flag and the data type are all read from the numpy data; an array
 * with at least one zero-sized dimension yields an "empty" shapeInfo that keeps its shape.
 *
 * @param npyArray pointer to the numpy data (passed to cnpy as char*)
 * @return pointer to the shapeInfo produced by ShapeBuilders
 *         (NOTE(review): ownership of the returned buffer is not visible here - confirm with callers)
 */
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
    cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
    unsigned int shapeSize = arr.shape.size();
    std::vector<Nd4jLong> shape(shapeSize);

    // copy the dimensions and remember whether any of them is zero
    bool _empty = false;
    for(unsigned int i = 0; i < shapeSize; i++) {
        shape[i] = arr.shape[i];

        if (arr.shape[i] == 0)
            _empty = true;
    }

    auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
    const char order = arr.fortranOrder ? 'f' : 'c';

    // _empty can only be set inside the loop above, i.e. when shapeSize > 0,
    // so a shape-less empty array cannot occur here
    Nd4jLong *shapeBuffer = _empty
            ? nd4j::ShapeBuilders::emptyShapeInfo(dtype, order, shape)
            : nd4j::ShapeBuilders::createShapeInfo(dtype, order, shape);

    return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);

View File

@ -454,7 +454,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
if (ews() != 1) { if (ews() != 1) {
for (uint i = 0; i < _length; i++) for (uint i = 0; i < _length; i++)
cudaMemcpyAsync(pHost + i * sizeof(T), getSpecialBuffer() + getOffset(i) * sizeof(T), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream())); cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
} }
else else
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
@ -475,6 +475,12 @@ template void NDArray::printCurrentBuffer<float>(const bool host, const char* ms
template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const; template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const;
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
#include <cpu/NDArrayLambda.hpp>
#endif
} // end namespace nd4j } // end namespace nd4j
#endif #endif

View File

@ -3105,3 +3105,29 @@ nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, double
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) { nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
} }
/**
 * Builds a shapeInfo buffer for the numpy array found at the given pointer.
 *
 * The shape, the ordering flag and the data type are all read from the numpy data; an array
 * with at least one zero-sized dimension yields an "empty" shapeInfo that keeps its shape.
 *
 * @param npyArray pointer to the numpy data (passed to cnpy as char*)
 * @return pointer to the shapeInfo produced by ShapeBuilders
 *         (NOTE(review): ownership of the returned buffer is not visible here - confirm with callers)
 */
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
    cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
    unsigned int shapeSize = arr.shape.size();
    std::vector<Nd4jLong> shape(shapeSize);

    // copy the dimensions and remember whether any of them is zero
    bool _empty = false;
    for(unsigned int i = 0; i < shapeSize; i++) {
        shape[i] = arr.shape[i];

        if (arr.shape[i] == 0)
            _empty = true;
    }

    auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
    const char order = arr.fortranOrder ? 'f' : 'c';

    // _empty can only be set inside the loop above, i.e. when shapeSize > 0,
    // so a shape-less empty array cannot occur here
    Nd4jLong *shapeBuffer = _empty
            ? nd4j::ShapeBuilders::emptyShapeInfo(dtype, order, shape)
            : nd4j::ShapeBuilders::createShapeInfo(dtype, order, shape);

    return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}

View File

@ -53,6 +53,15 @@ namespace nd4j {
template <typename T> template <typename T>
FORCEINLINE static _CUDA_HD T max(); FORCEINLINE static _CUDA_HD T max();
/**
 * returns inf for float/double and max for everything else
 */
template <typename T>
FORCEINLINE static _CUDA_HD T infOrMax();
/**
 * returns quiet NaN for float/double; the generic definition returns zero converted to T
 */
template <typename T>
FORCEINLINE static _CUDA_HD T nanOrZero();
// returns the difference between 1.0 and the next representable value of the given floating-point type // returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T> template <typename T>
FORCEINLINE static T eps(); FORCEINLINE static T eps();
@ -290,6 +299,36 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max<bfloat16>() {
return bfloat16::max(); return bfloat16::max();
} }
// generic case: the largest value reported by DataTypeUtils::max<T>()
template <typename T>
FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() {
    return DataTypeUtils::max<T>();
}

// float: positive infinity
template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax<float>() {
    return std::numeric_limits<float>::infinity();
}

// double: positive infinity
template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax<double>() {
    return std::numeric_limits<double>::infinity();
}
// generic case: zero converted to T
template <typename T>
FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() {
    return static_cast<T>(0);
}

// float: quiet NaN
template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero<float>() {
    return std::numeric_limits<float>::quiet_NaN();
}

// double: quiet NaN
template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero<double>() {
    return std::numeric_limits<double>::quiet_NaN();
}
FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
switch(dataType) { switch(dataType) {
case INT8: case INT8:

View File

@ -55,8 +55,13 @@ bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const {
} }
Nd4jLong* ShapeDescriptor::toShapeInfo() const { Nd4jLong* ShapeDescriptor::toShapeInfo() const {
if (_empty) if (_empty) {
return ShapeBuilders::emptyShapeInfo(_dataType); if (_rank == 0)
return ShapeBuilders::emptyShapeInfo(_dataType);
else {
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
}
}
switch (_rank) { switch (_rank) {
@ -133,15 +138,11 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) { ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
_rank = ((shape.size() == 1 && shape[0] == 0)? 0: shape.size()); _rank = shape.size();
_ews = 1; _ews = 1;
if (_rank > 0) { if (_rank > 0) {
_strides.resize(_rank); _strides.resize(_rank);
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
for (auto v:_shape) { for (auto v:_shape) {
if (v == 0) { if (v == 0) {
@ -149,6 +150,17 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st
break; break;
} }
} }
// no point calculating strides for empty arrays
if (!_empty) {
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
} else {
// all strides set to 0
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
}
} }
} }
@ -191,8 +203,11 @@ ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
_empty = shape::isEmpty(shapeInfo); _empty = shape::isEmpty(shapeInfo);
for (int e = 0; e < _rank; e++) for (int e = 0; e < _rank; e++) {
_shape.emplace_back(shapeInfo[e + 1]); _shape.emplace_back(shapeInfo[e + 1]);
if (shapeInfo[e + 1] == 0)
_empty = true;
}
for (int e = 0; e < _rank; e++) for (int e = 0; e < _rank; e++)
_strides.emplace_back(shapeInfo[e + 1 + _rank]); _strides.emplace_back(shapeInfo[e + 1 + _rank]);
@ -304,7 +319,14 @@ ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const D
ShapeDescriptor descriptor; ShapeDescriptor descriptor;
descriptor._dataType = type; descriptor._dataType = type;
descriptor._shape.emplace_back(length); descriptor._shape.emplace_back(length);
descriptor._strides.emplace_back(1);
if (length > 0)
descriptor._strides.emplace_back(1);
else {
descriptor._strides.emplace_back(0);
descriptor._empty = true;
}
descriptor._order = 'c'; descriptor._order = 'c';
descriptor._ews = 1; descriptor._ews = 1;
descriptor._rank = 1; descriptor._rank = 1;

View File

@ -29,7 +29,7 @@
#include <array/ArrayOptions.h> #include <array/ArrayOptions.h>
namespace nd4j { namespace nd4j {
class ShapeBuilders { class ND4J_EXPORT ShapeBuilders {
public: public:
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr); static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
@ -53,6 +53,8 @@ namespace nd4j {
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr); static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
}; };
} }

View File

@ -40,6 +40,12 @@ namespace nd4j {
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr); static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr); static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
/**
* evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to tf
*/
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
// evaluate shape for array which is result of repeat operation applied to arr // evaluate shape for array which is result of repeat operation applied to arr
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr); static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);

View File

@ -71,7 +71,9 @@ namespace nd4j {
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto oPtr = new Nd4jLong[numOfSubArrs]; auto oPtr = new Nd4jLong[numOfSubArrs];
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);

View File

@ -75,7 +75,8 @@ namespace nd4j {
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto oPtr = new Nd4jLong[numOfSubArrs]; auto oPtr = new Nd4jLong[numOfSubArrs];
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
Nd4jPointer soPtr; Nd4jPointer soPtr;
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong)); auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));

View File

@ -54,11 +54,6 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) { Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) {
if (rank)
if(shapeOnly[0] == 0) // scalar case
rank = 0;
Nd4jLong* shapeInfo = nullptr; Nd4jLong* shapeInfo = nullptr;
if(rank == 0) { // scalar case if(rank == 0) { // scalar case
@ -67,10 +62,23 @@ namespace nd4j {
else { else {
ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
shapeInfo[0] = rank; shapeInfo[0] = rank;
for(int i = 0; i < rank; ++i) bool isEmpty = false;
for(int i = 0; i < rank; ++i) {
shapeInfo[i + 1] = shapeOnly[i]; shapeInfo[i + 1] = shapeOnly[i];
shape::updateStrides(shapeInfo, order); if (shapeOnly[i] == 0)
isEmpty = true;
}
if (!isEmpty) {
shape::updateStrides(shapeInfo, order);
}
else {
shapeInfo[shape::shapeInfoLength(rank) - 1] = order;
memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong));
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
}
nd4j::ArrayOptions::setDataType(shapeInfo, dataType); nd4j::ArrayOptions::setDataType(shapeInfo, dataType);
} }
@ -78,9 +86,16 @@ namespace nd4j {
} }
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) { Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) {
auto shape = createScalarShapeInfo(dataType, workspace); auto shapeInfo = createScalarShapeInfo(dataType, workspace);
ArrayOptions::setPropertyBit(shape, ARRAY_EMPTY); ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
return shape; return shapeInfo;
}
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace) {
auto shapeInfo = createShapeInfo(dataType, order, shape, workspace);
memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong));
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
return shapeInfo;
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -108,27 +108,81 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt); return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
} }
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, arr, arr.dataType(), keepDims, supportOldShapes, workspace); //////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {
if (dimsToExclude.size() == 0) { // return copy of input shape
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
ShapeDescriptor descriptor(outShapeInfo, dataType);
RELEASE(outShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
const int rank = shape::rank(shapeInfo);
Nd4jLong* outShapeInfo = nullptr;
if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities
if(!keepDims)
outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
else
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
}
else {
shape::checkDimensions(rank, dimsToExclude);
std::vector<Nd4jLong> outShape;
if(keepDims) {
outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
for(const auto& dim : dimsToExclude)
outShape[dim] = 1;
}
else {
for (uint i = 0, j = 0; i < rank; ++i) {
if(j < dimsToExclude.size() && i == dimsToExclude[j])
++j;
else
outShape.emplace_back(shapeInfo[i + 1]);
}
}
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
}
ShapeDescriptor descriptor(outShapeInfo, dataType);
RELEASE(outShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
} }
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace); return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation // evaluate shape resulting from reduce operation
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo)); int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));
if (dimensions.size() == 0) { // return scalar or array with len=1 in this case if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case
if(keepDims && rank > 1) { if(keepDims && rank > 1) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
@ -157,16 +211,16 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
} }
} }
shape::checkDimensions(rank, dimensions); shape::checkDimensions(rank, dimsToExclude);
int dimSize = dimensions.size(); int dimSize = dimsToExclude.size();
if(keepDims) { if(keepDims) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
newShapeInfo[0] = rank; newShapeInfo[0] = rank;
for(int i = 0; i < rank; ++i) for(int i = 0; i < rank; ++i)
if (std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
newShapeInfo[i+1] = 1; newShapeInfo[i+1] = 1;
else else
newShapeInfo[i+1] = shapeInfo[i+1]; newShapeInfo[i+1] = shapeInfo[i+1];
@ -178,7 +232,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
} }
int newRank = rank - dimSize; int newRank = rank - dimSize;
if (newRank==0 || (dimSize==1 && dimensions[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
if(supportOldShapes) { if(supportOldShapes) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
@ -199,7 +253,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
newShapeInfo[0] = newRank; // set rank newShapeInfo[0] = newRank; // set rank
int j=1; int j=1;
for(int i = 0; i < rank; ++i) for(int i = 0; i < rank; ++i)
if (!std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
newShapeInfo[j++] = shapeInfo[i+1]; newShapeInfo[j++] = shapeInfo[i+1];
//ensure whether vector has proper shape for old shape type //ensure whether vector has proper shape for old shape type
@ -208,7 +262,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
RELEASE(newShapeInfo, workspace); RELEASE(newShapeInfo, workspace);
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2 ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2
newShapeInfo[0] = 2; newShapeInfo[0] = 2;
if (dimensions[0] == 0) { if (dimsToExclude[0] == 0) {
newShapeInfo[1] = 1; newShapeInfo[1] = 1;
newShapeInfo[2] = oldValue; newShapeInfo[2] = oldValue;
} }
@ -422,8 +476,23 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
// nullify zero axis
for (int e = 0; e < maxRank; e++)
if (maxShapeInfo[e+1] == 0)
tmpShapeInfo[e+1] = 0;
int delta = maxRank - minRank;
for (int e = minRank - 1; e >= 0; e--)
if (minShapeInfo[e + 1] == 0)
tmpShapeInfo[e + 1 + delta] = 0;
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
}
ShapeDescriptor descriptor(tmpShapeInfo); ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace); RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>(); resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
@ -805,7 +874,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo,
nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]); nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
throw std::invalid_argument(""); throw std::invalid_argument("");
} }
return std::vector<Nd4jLong>({0}); return std::vector<Nd4jLong>({});
} }

View File

@ -1992,7 +1992,7 @@ template <typename T>
len = shape::length(shapeInfo); len = shape::length(shapeInfo);
//check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute //check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute
if(len < 2) if(len == 1)
return; return;
const int rank = shape::rank(shapeInfo); const int rank = shape::rank(shapeInfo);
@ -3961,7 +3961,7 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
newDim = newShape[newStart]; newDim = newShape[newStart];
oldDim = oldShape[oldStart]; oldDim = oldShape[oldStart];
while (newDim != oldDim) while (newDim != oldDim && newDim > 0 && oldDim > 0)
if (newDim < oldDim) newDim *= newShape[newStop++]; if (newDim < oldDim) newDim *= newShape[newStop++];
else oldDim *= oldShape[oldStop++]; else oldDim *= oldShape[oldStop++];

View File

@ -116,13 +116,23 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams); auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong zLen = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto indexValue = OpType::startingIndexValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLen; i++)
z[i] = indexValue.index;;
return;
}
if(shape::isScalar(zShapeInfo)) { if(shape::isScalar(zShapeInfo)) {
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams); z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
return; return;
} }
const Nd4jLong zLen = shape::length(zShapeInfo);
auto tadOnlyShapeInfo = tadShapeInfo; auto tadOnlyShapeInfo = tadShapeInfo;
Nd4jLong *tadOffsets = tadOffset; Nd4jLong *tadOffsets = tadOffset;

View File

@ -46,6 +46,21 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) { if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
@ -157,6 +172,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo); auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original //pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset //shape information for tad offset
//the squeezed information doesn't render the right strides for //the squeezed information doesn't render the right strides for

View File

@ -46,6 +46,25 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
if (std::is_same<OpType, simdOps::Mean<X,Z>>::value) {
z[0] = nd4j::DataTypeUtils::nanOrZero<Z>();
} else {
z[0] = OpType::startingValue(x);
}
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws > 0) { if (xEws > 0) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
@ -165,6 +184,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo); auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original //pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset //shape information for tad offset
//the squeezed information doesn't render the right strides for //the squeezed information doesn't render the right strides for

View File

@ -46,6 +46,21 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo); const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) { if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
} }
@ -159,6 +174,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo); auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original //pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset //shape information for tad offset
//the squeezed information doesn't render the right strides for //the squeezed information doesn't render the right strides for

View File

@ -48,6 +48,20 @@ namespace functions {
const auto xEws = shape::elementWiseStride(xShapeInfo); const auto xEws = shape::elementWiseStride(xShapeInfo);
const int rank = shape::rank(xShapeInfo); const int rank = shape::rank(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) { if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams); z[0] = execScalar<OpType>(x, xEws, length, extraParams);
@ -71,7 +85,7 @@ namespace functions {
for (int e = 0; e < maxThreads; e++) for (int e = 0; e < maxThreads; e++)
start = OpType::update(start, intermediate[e], extraParams); start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams); z[0] = OpType::postProcess(start, length, extraParams);
} }
} }
@ -171,6 +185,16 @@ namespace functions {
auto zLength = shape::length(zShapeInfo); auto zLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original //pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset //shape information for tad offset
//the squeezed information doesn't render the right strides for //the squeezed information doesn't render the right strides for

View File

@ -47,6 +47,16 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
auto xEws = shape::elementWiseStride(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY || nd4j::ArrayOptions::arrayType(yShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f}; Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
// it's possible case for EqualsWithEps op // it's possible case for EqualsWithEps op
if (extraParams != nullptr) if (extraParams != nullptr)

View File

@ -47,8 +47,8 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *extraParams, void *extraParams,
void *z, void *z,
Nd4jLong *resultShapeInfoBuffer) { Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer), SUMMARY_STATS_OPS); DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
} }
template <typename X, typename Y> template <typename X, typename Y>
@ -58,10 +58,10 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *extraParams, void *extraParams,
void *z, void *z,
Nd4jLong *resultShapeInfoBuffer, Nd4jLong *zShapeInfo,
int *dimension, int *dimension,
int dimensionLength) { int dimensionLength) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer, dimension, dimensionLength), SUMMARY_STATS_OPS); DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
} }
template <typename X, typename Z> template <typename X, typename Z>
@ -71,7 +71,7 @@ namespace functions {
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vz, void *vz,
Nd4jLong *resultShapeInfoBuffer) { Nd4jLong *zShapeInfo) {
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams); z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
} }
@ -108,20 +108,31 @@ namespace functions {
void *vx, void *vx,
Nd4jLong *xShapeInfo, Nd4jLong *xShapeInfo,
void *vextraParams, void *vextraParams,
void *vresult, void *vz,
Nd4jLong *resultShapeInfoBuffer, Nd4jLong *zShapeInfo,
int *dimension, int *dimension,
int dimensionLength) { int dimensionLength) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (shape::isScalar(resultShapeInfoBuffer)) { auto x = reinterpret_cast<X *>(vx);
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
int resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
SummaryStatsData<X> comp;
comp.initWithValue(x[0]);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = OpType::getValue(biasCorrected, comp);
return; return;
} }
if (shape::isScalar(zShapeInfo)) {
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
return;
}
//no-op //no-op
if (dimensionLength < 1) if (dimensionLength < 1)
@ -129,7 +140,6 @@ namespace functions {
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
int resultLength = shape::length(resultShapeInfoBuffer);
//pre squeezed: this is for keeping the pointer to the original //pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset //shape information for tad offset
//the squeezed information doesn't render the right strides for //the squeezed information doesn't render the right strides for

View File

@ -131,7 +131,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG); auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList; return shapeList;
} }

View File

@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput); auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2d(*block.launchContext(), inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped; delete inputReshaped;
delete outputReshaped; delete outputReshaped;
@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2dBP(*block.launchContext(), inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped; delete inputReshaped;
delete gradIReshaped; delete gradIReshaped;

View File

@ -63,7 +63,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK(); return Status::OK();
} }
@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK(); return Status::OK();
} }
@ -305,7 +305,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK(); return Status::OK();
} }

View File

@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradW and gradB ----- // // ----- calculation of gradW and gradB ----- //
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
if(gradB) { if(gradB) {
@ -469,7 +469,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradI -----// //----- calculation of gradI -----//
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
ConvolutionUtils::col2vol(*block.launchContext(), columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
if(!isNDHWC) { if(!isNDHWC) {
delete input; delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK(); return Status::OK();
} }

View File

@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW] // NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(*block.launchContext(), columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
//----- add biases if required -----// //----- add biases if required -----//
if(bias) if(bias)
@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradW ----- // // ----- calculation of gradW ----- //
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
// ----- calculation of gradB ----- // // ----- calculation of gradB ----- //

View File

@ -62,7 +62,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK(); return Status::OK();
} }
@ -185,7 +185,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
if(bias) if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK(); return Status::OK();
} }

View File

@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
if (bias) if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
return Status::OK(); return Status::OK();
} }

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
//output->printBuffer("output op"); //output->printBuffer("output op");
if (!isNCHW) { if (!isNCHW) {
@ -198,7 +198,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
// *gradI /= kH*kW; // *gradI /= kH*kW;
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0); ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;

View File

@ -69,7 +69,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
//T extraParams[] = {}; //T extraParams[] = {};
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) { if(!isNCDHW) {
delete input; delete input;
@ -189,7 +189,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) { if(!isNCDHW) {
delete input; delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1); ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
if (!isNCHW) { if (!isNCHW) {
delete input; delete input;
@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); // columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.); ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
if(isSameMode) // SAME if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
if(!isNCDHW) { if(!isNCDHW) {
delete input; delete input;
@ -204,7 +204,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
// ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] // ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary;
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
if(!isNCDHW) { if(!isNCDHW) {
delete input; delete input;

View File

@ -68,7 +68,7 @@ namespace nd4j {
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0); ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
if (!isNCHW) { if (!isNCHW) {
delete input; delete input;
@ -209,7 +209,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); // columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm); ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
if(!isNCHW) { if(!isNCHW) {
delete input; delete input;

View File

@ -84,11 +84,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
if (iC == 1) { if (iC == 1) {
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
ConvolutionUtils::conv2d(*block.launchContext(), input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK(); return Status::OK();
} }
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK(); return Status::OK();
} }
@ -274,12 +274,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC}); auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
ConvolutionUtils::conv2dBP(*block.launchContext(), resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
gradO = gradIDepth; gradO = gradIDepth;
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
@ -288,7 +288,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
} }
// ----- apply depthwise_conv2d_bp ----- // // ----- apply depthwise_conv2d_bp ----- //
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
if(weightsPoint) if(weightsPoint)
delete gradO; delete gradO;

View File

@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf());
REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf()); REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf());
ConvolutionUtils::upsampling2d(*block.launchContext(), *input, *output, factorH, factorW, (bool)isNCHW); ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW);
return Status::OK(); return Status::OK();
} }
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
ConvolutionUtils::upsampling2dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCHW); ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW);
return Status::OK(); return Status::OK();
} }

View File

@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf());
REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf()); REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf());
ConvolutionUtils::upsampling3d(*block.launchContext(), *input, *output, factorD, factorH, factorW, (bool)isNCDHW); ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
return Status::OK(); return Status::OK();
} }
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
ConvolutionUtils::upsampling3dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCDHW); ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW);
return Status::OK(); return Status::OK();
} }

View File

@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
} }
std::vector<Nd4jLong> shape({2, mean->lengthOf()}); std::vector<Nd4jLong> shape({2, mean->lengthOf()});
NDArray weights = NDArrayFactory::create<float>('c', shape, block.getWorkspace()); NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
weights({0, 1, 0, 0}).assign(1.0f); weights({0, 1, 0, 0}).assign(1.0f);
weights({1, 2, 0, 0}).assign(0.0f); weights({1, 2, 0, 0}).assign(0.0f);

View File

@ -72,6 +72,11 @@ namespace nd4j {
if (dims.size() > 1) if (dims.size() > 1)
std::sort(dims.begin(), dims.end()); std::sort(dims.begin(), dims.end());
for (auto d:dims) {
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar // special case - output is scalar
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) { if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64)); return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));

View File

@ -72,6 +72,10 @@ namespace nd4j {
if (dims.size() > 1) if (dims.size() > 1)
std::sort(dims.begin(), dims.end()); std::sort(dims.begin(), dims.end());
for (auto d:dims) {
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar // special case - output is scalar
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) { if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));

View File

@ -71,10 +71,8 @@ namespace nd4j {
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong); ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong);
newShape[0] = len; newShape[0] = len;
auto empty = false;
for (int e = 0; e < shapeArray->lengthOf(); e++){ for (int e = 0; e < shapeArray->lengthOf(); e++){
newShape[e+1] = shapeArray->e<Nd4jLong>(e); newShape[e+1] = shapeArray->e<Nd4jLong>(e);
empty |= (newShape[e+1] == 0); //Support "zeros in shape as empty" for TF import
} }
nd4j::DataType dataType; nd4j::DataType dataType;
@ -90,10 +88,6 @@ namespace nd4j {
} else } else
throw std::runtime_error("Fill: missing value to fill output array with"); throw std::runtime_error("Fill: missing value to fill output array with");
if(empty){
return SHAPELIST(ShapeBuilders::emptyShapeInfo(dataType, block.getWorkspace()));
}
ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); ShapeUtils::updateStridesAndType(newShape, dataType, 'c');
return SHAPELIST(CONSTANT(newShape)); return SHAPELIST(CONSTANT(newShape));

View File

@ -151,8 +151,10 @@ DECLARE_SHAPE_FN(range) {
delta = INPUT_VARIABLE(2)->e<double>(0); delta = INPUT_VARIABLE(2)->e<double>(0);
} }
if (limit == start) if (limit == start){
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); //Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -177,8 +179,10 @@ DECLARE_SHAPE_FN(range) {
//nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta) //nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta)
if (limit == start) if (limit == start){
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); //Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -203,8 +207,10 @@ DECLARE_SHAPE_FN(range) {
delta = INT_ARG(2); delta = INT_ARG(2);
} }
if (limit == start) if (limit == start){
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(nd4j::DataType::INT32)); //Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::INT32));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -233,9 +239,10 @@ DECLARE_SHAPE_FN(range) {
delta = T_ARG(2); delta = T_ARG(2);
} }
//REQUIRE_TRUE(limit != start, 0, "CUSTOM RANGE OP: limit and start values should be different, but got both equal to %f !", limit); if (limit == start){
if (limit == start) //Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(Environment::getInstance()->defaultFloatDataType())); return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType()));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
output->assign(static_cast<Nd4jLong>(input->rankOf())); // output->assign(static_cast<Nd4jLong>(input->rankOf()));
output->assign(input->rankOf());
return Status::OK(); return Status::OK();
} }

View File

@ -132,17 +132,13 @@ namespace nd4j {
REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]); REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]);
REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]);
REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]); REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]);
if(start == inShape[e+1] || size == 0 ){ if(start == inShape[e+1] ){
empty = true; size = 0;
} }
shape.emplace_back(size); shape.emplace_back(size);
} }
if(empty){
return SHAPELIST(ShapeBuilders::emptyShapeInfo(nd4j::DataType::INT32, block.getWorkspace()));
}
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }

View File

@ -34,6 +34,10 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
if(dim < 0) if(dim < 0)
dim += input->rankOf() + 1; dim += input->rankOf() + 1;
// no-op in case of empty output array
if (output->isEmpty())
return Status::OK();
// input validation // input validation
// check whether shapes of all input array are the same // check whether shapes of all input array are the same
for (int i = 0; i < (int) block.width() - 1; ++i) for (int i = 0; i < (int) block.width() - 1; ++i)
@ -48,16 +52,6 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
helpers::stack(block.launchContext(), inArrs, output, dim); helpers::stack(block.launchContext(), inArrs, output, dim);
// remove unity from output shape if input arrays are vectors
// if(input->isVector()) {
// std::vector<int> outShape(output->shapeOf(), output->shapeOf() + output->rankOf());
// outShape.erase(find(outShape.begin(), outShape.end(), 1));
// output->reshapei(output->ordering(), outShape);
// if(dim != 0 && (int)block.width() == 1) // such is implemented by tensorFlow
// output->permutei({1, 0});
// output->getShapeInfo()[output->rankOf()*2 + 2] = 1;
// }
return Status::OK(); return Status::OK();
} }
DECLARE_SYN(pack, stack); DECLARE_SYN(pack, stack);
@ -82,8 +76,22 @@ DECLARE_SHAPE_FN(stack) {
REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim); REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim);
// empty input arrays require some special handling
if (shape::isEmpty(inShapeInfo)) {
switch (rank) {
case 0: {
// we're going to return rank 1 here
if (block.width() == 1) {
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo)));
} else {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0}));
}
}
}
}
if(rank == 0) { if(rank == 0) {
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo))); return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
} }
//the rank of output ShapeInfo is larger by one compared to input ShapeInfo //the rank of output ShapeInfo is larger by one compared to input ShapeInfo
@ -91,13 +99,9 @@ DECLARE_SHAPE_FN(stack) {
// insert (int) block.width() at dim position of input shape to get output shape // insert (int) block.width() at dim position of input shape to get output shape
outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width()); outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width());
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape))); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
} }
// 1) 1х4 + 1х4 = 2х1х4 (along dim=0) = 2x4
// 2) 1х4 + 1х4 = 1х2х4 (along dim=1) = 2x4
// 3) 4х1 + 4х1 = 2х4x1 (along dim=0) = 2x4
// 4) 4х1 + 4х1 = 4х2x1 (along dim=1) = 4x2
} }
} }

View File

@ -410,10 +410,13 @@ namespace nd4j {
// z->assign(x->e<float>(indices[0])); // z->assign(x->e<float>(indices[0]));
// } // }
// else { // else {
auto sub = (*x)(indices, true, true); if (indices.size()) {
z->assign(sub); auto sub = (*x)(indices, true, true);
// } z->assign(sub);
}
else if (!z->isEmpty()){
z->assign(x->e(0));
}
return Status::OK(); return Status::OK();
} }
DECLARE_SYN(stridedslice, strided_slice); DECLARE_SYN(stridedslice, strided_slice);
@ -496,28 +499,19 @@ namespace nd4j {
bool is_simple_slice; bool is_simple_slice;
bool is_dim0; bool is_dim0;
// FIXME: remove this, once we bring in 1D NDArrays std::vector<Nd4jLong> indices;
//vectorize(input_shape); bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
bool result = _preprocess_strided_slice(nullptr, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); if (indices.size()) {
bool nonEmpty = shape.size() > 0; newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
if (nonEmpty) shape);
for (auto x: shape) { if (inputLen > 1) {
if (x == 0) { newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
nonEmpty = false; shape);
break;
}
}
if (nonEmpty && inputLen > 1) {
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
}
else {
if (shape::rank(inShape) == 0 || begin >= end) {
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
} else { } else {
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
} }
} } else
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
return SHAPELIST(newShape); return SHAPELIST(newShape);
} }

View File

@ -37,6 +37,9 @@ namespace nd4j {
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim); REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim); REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
if(input->isEmpty())
return Status::OK();
std::vector<int> dims; std::vector<int> dims;
for (int e = 0; e < input->rankOf(); e++) for (int e = 0; e < input->rankOf(); e++)
if (e != dim) if (e != dim)
@ -76,6 +79,21 @@ namespace nd4j {
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim); REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim); REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
if(shape::shapeOf(inShape)[dim] == 0)
return SHAPELIST();
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
std::vector<Nd4jLong> outShape;
for(uint i = 0; i < shape::rank(inShape); ++i)
if(i != dim)
outShape.push_back(shape::shapeOf(inShape)[i]);
auto result = SHAPELIST();
for(uint i = 0; i < numTads; ++i)
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
return result;
}
std::vector<int> dims; std::vector<int> dims;
for (int e = 0; e < shape::rank(inShape); e++) for (int e = 0; e < shape::rank(inShape); e++)
if (e != dim) if (e != dim)

View File

@ -30,6 +30,12 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
if(input->isEmpty()){
output->p<double>(0, std::numeric_limits<double>::quiet_NaN());
return Status::OK();
}
int numZeros = 0; int numZeros = 0;
// for (int e = 0; e < input->lengthOf(); e++) // for (int e = 0; e < input->lengthOf(); e++)
// if ((*input)(e) == T(0)) // if ((*input)(e) == T(0))

View File

@ -113,16 +113,10 @@ DECLARE_SHAPE_FN(lstmBlock) {
} }
ShapeUtils::updateStridesAndType(s, x, 'c'); ShapeUtils::updateStridesAndType(s, x, 'c');
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6; Nd4jLong *s1 = CONSTANT(s);
COPY_SHAPE(s, s1);
COPY_SHAPE(s, s2);
COPY_SHAPE(s, s3);
COPY_SHAPE(s, s4);
COPY_SHAPE(s, s5);
COPY_SHAPE(s, s6);
//7 outputs, all same shape/type //7 outputs, all same shape/type
return SHAPELIST(s, s1, s2, s3, s4, s5, s6); return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
} }
} }

View File

@ -115,16 +115,10 @@ DECLARE_SHAPE_FN(lstmBlockCell) {
ShapeUtils::updateStridesAndType(s, xt, 'c'); ShapeUtils::updateStridesAndType(s, xt, 'c');
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6; Nd4jLong *s1 = CONSTANT(s);
COPY_SHAPE(s, s1);
COPY_SHAPE(s, s2);
COPY_SHAPE(s, s3);
COPY_SHAPE(s, s4);
COPY_SHAPE(s, s5);
COPY_SHAPE(s, s6);
//7 outputs, all same shape: z, i, f, o, h, c, y //7 outputs, all same shape: z, i, f, o, h, c, y
return SHAPELIST(s, s1, s2, s3, s4, s5, s6); return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
} }
} }

View File

@ -78,7 +78,6 @@ DECLARE_SHAPE_FN(broadcast_to) {
REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape); auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape);
return SHAPELIST(outShapeInfo); return SHAPELIST(outShapeInfo);
} }

View File

@ -34,26 +34,26 @@ namespace nd4j {
bool replace = false; bool replace = false;
auto arguments = block.getIArguments(); auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
if (block.width() == 2 && arguments->size() == 0) { std::vector<int> arguments({});
auto axis = INPUT_VARIABLE(1); if(origArgs.size() > 0){
for (int e = 0; e < axis->lengthOf(); e++) { for (int e = 0; e < origArgs.size(); e++) {
int ax = axis->e<int>(e); int ax = origArgs[e];
if (ax < 0) if (ax < 0)
ax += x->rankOf(); ax += x->rankOf();
arguments->emplace_back(ax); arguments.emplace_back(ax);
} }
replace = true; replace = true;
} else if (arguments->size() == 0) { } else {
for (int e = x->rankOf() - 1; e >= 0; e--) for (int e = x->rankOf() - 1; e >= 0; e--)
arguments->emplace_back(e); arguments.emplace_back(e);
} }
// 0D edge case // 0D edge case
if (x->rankOf() == 0) { if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments->size() == 1, 0, "Permute: only one axis is allowed for scalar"); REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace()) if (!block.isInplace())
output->assign(x); output->assign(x);
@ -62,25 +62,17 @@ namespace nd4j {
} }
if(block.isInplace()) { // in-place if(block.isInplace()) { // in-place
x->permutei(*arguments); x->permutei(arguments);
STORE_RESULT(x); STORE_RESULT(x);
} else { } else {
if (!replace) { // not-in-place auto output = OUTPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto result = x->permute(arguments);
// nd4j_printv("permute shape", *arguments); output->assign(result);
auto result = x->permute(*arguments); STORE_RESULT(output);
output->assign(result); delete result;
STORE_RESULT(output);
delete result;
} else {
auto output = OUTPUT_VARIABLE(0); //->dup();
output->assign(x);
output->permutei(*arguments);
//OVERWRITE_RESULT(output);
}
} }
return Status::OK();
return Status::OK();
} }
DECLARE_TYPES(permute) { DECLARE_TYPES(permute) {
@ -92,20 +84,21 @@ namespace nd4j {
DECLARE_SHAPE_FN(permute) { DECLARE_SHAPE_FN(permute) {
auto shapeList = SHAPELIST(); auto shapeList = SHAPELIST();
auto arguments = block.getIArguments(); auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) { if (shape::rank(inputShape->at(0)) == 0) {
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
} else if (inputShape->size() == 1 && !arguments->empty()) { } else if (inputShape->size() == 1 && !arguments.empty()) {
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace())); shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
} else if (inputShape->size() == 2) {
// dead end
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
} else { } else {
int rank = shape::rank(inputShape->at(0)); if(arguments.size() == 0){
for (int e = rank - 1; e >= 0; e--) //Reverse dimensions
arguments->emplace_back(e); int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments.emplace_back(e);
}
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace())); shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
} }
return shapeList; return shapeList;

View File

@ -35,9 +35,8 @@ namespace nd4j {
auto arguments = block.getIArguments(); auto arguments = block.getIArguments();
int argsSize = arguments->size(); int argsSize = arguments->size();
//Special case: empty.reshape(-1) -> return empty //Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) { if (x->isEmpty()) {
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return ND4J_STATUS_OK; //No op return ND4J_STATUS_OK; //No op
} }
@ -96,9 +95,9 @@ namespace nd4j {
//Special case: empty.reshape(-1) -> return empty //Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) { if (x->isEmpty()) {
REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); //REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return ND4J_STATUS_OK; //No op return Status::OK(); //No op
} }
char order = 'c'; char order = 'c';
@ -116,7 +115,8 @@ namespace nd4j {
} }
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){ for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2); shapeLength *=
s->e<Nd4jLong>(e2);
} }
long realShape = x->lengthOf() / shapeLength; long realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape; shapeNew[e] = realShape;
@ -175,12 +175,12 @@ namespace nd4j {
e = 0; e = 0;
} }
//Special case: empty.reshape(-1) -> return empty // //Special case: empty.reshape(-1) -> return empty
if (INPUT_VARIABLE(0)->isEmpty()) { // if (INPUT_VARIABLE(0)->isEmpty()) {
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]"); // //
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); // auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
return SHAPELIST(newShape); // return SHAPELIST(newShape);
} // }
std::vector<Nd4jLong> shapeNew; std::vector<Nd4jLong> shapeNew;
@ -197,8 +197,14 @@ namespace nd4j {
shapeLength *= arguments->at(e2); shapeLength *= arguments->at(e2);
} }
long realShape = shape::length(inp) / shapeLength; if(shapeLength == 0){
shapeNew.push_back(realShape); //Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
long realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
} }
else{ else{
shapeNew.push_back(arguments->at(e)); shapeNew.push_back(arguments->at(e));
@ -218,9 +224,16 @@ namespace nd4j {
} }
//Special case: empty.reshape(-1) -> return empty //Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) { if (x->isEmpty()) {
REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); //REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); auto shapeOf = y->getBufferAsVector<Nd4jLong>();
return SHAPELIST(newShape); Nd4jLong prod = 1;
for (auto v:shapeOf)
prod *= v;
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
} }
std::vector<Nd4jLong> shapeNew(y->lengthOf()); std::vector<Nd4jLong> shapeNew(y->lengthOf());
@ -236,8 +249,14 @@ namespace nd4j {
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2); shapeLength *= y->e<Nd4jLong>(e2);
} }
long realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape; if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
long realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else { }else {
shapeNew[e] = dim; shapeNew[e] = dim;
} }

View File

@ -38,21 +38,26 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
std::vector<int> arrsToDelete; std::vector<int> arrsToDelete;
int index = 0; int index = 0;
bool allOfSameType = true; bool allOfSameType = true;
auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0;
auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType();
for(int i = 0; i < block.width(); ++i) { for(int i = 0; i < block.width(); ++i) {
auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf();
if(!INPUT_VARIABLE(i)->isEmpty()) { // TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
if(!input->isEmpty()) {
allOfSameType &= (INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType()); allOfSameType &= (theFirstDatatype == input->dataType());
if(INPUT_VARIABLE(i)->rankOf() == 0) { if(input->rankOf() == 0) {
// FIXME, use this instead: block.dataType() auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
auto vec = new NDArray('c', {1}, INPUT_VARIABLE(0)->dataType(), block.launchContext()); vec->assign(input);
vec->assign(INPUT_VARIABLE(i));
nonEmptyArrs.push_back(vec); nonEmptyArrs.push_back(vec);
arrsToDelete.push_back(index); arrsToDelete.push_back(index);
} }
else{ else{
nonEmptyArrs.push_back(INPUT_VARIABLE(i)); nonEmptyArrs.push_back(input);
} }
++index; ++index;
} }
@ -113,33 +118,24 @@ DECLARE_SHAPE_FN(concat) {
// first of all take into account possible presence of empty arrays // first of all take into account possible presence of empty arrays
// also if scalar is present -> use the shape of vector with length=1 instead // also if scalar is present -> use the shape of vector with length=1 instead
std::vector<Nd4jLong*> nonEmptyArrShapes; std::vector<Nd4jLong*> arrShapes;
std::vector<int> shapesToDelete; std::vector<int> shapesToDelete;
int index = 0; int index = 0;
for(int i = 0; i < block.width(); ++i) { for(int i = 0; i < block.width(); ++i) {
if(!INPUT_VARIABLE(i)->isEmpty()) { if(inputShape->at(i)[0] == 0) {
// FIXME, use this instead: block.dataType()
if(inputShape->at(i)[0] == 0) { arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
// FIXME, use this instead: block.dataType()
nonEmptyArrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
}
else{
nonEmptyArrShapes.push_back(inputShape->at(i));
}
++index;
} }
else{
arrShapes.push_back(inputShape->at(i));
}
++index;
} }
const int numOfArrs = nonEmptyArrShapes.size(); const int numOfArrs = arrShapes.size();
if(numOfArrs == 0){ const int rank = arrShapes[0][0];
//All inputs are empty arrays -> return empty, mainly for TF import compatibility
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
return SHAPELIST(empty);
}
const int rank = nonEmptyArrShapes[0][0]; // look up to first non-empty array
int axis = INT_ARG(0); int axis = INT_ARG(0);
if(axis < 0) if(axis < 0)
@ -149,33 +145,33 @@ DECLARE_SHAPE_FN(concat) {
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfArrs; ++i) for(int i = 1; i < numOfArrs; ++i)
REQUIRE_TRUE(nonEmptyArrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !"); REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i) { for(int i = 1; i < numOfArrs; ++i) {
for(int dim = 0; dim < rank; ++dim) for(int dim = 0; dim < rank; ++dim)
if(dim != axis) if(dim != axis)
REQUIRE_TRUE(nonEmptyArrShapes[i][dim+1] == nonEmptyArrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
} }
// ******** end of input validation ******** // // ******** end of input validation ******** //
Nd4jLong* outShapeInfo(nullptr); Nd4jLong* outShapeInfo(nullptr);
COPY_SHAPE(nonEmptyArrShapes[0], outShapeInfo); COPY_SHAPE(arrShapes[0], outShapeInfo);
// case when we have only one input array // case when we have only one input array
if(numOfArrs == 1) { if(numOfArrs == 1) {
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0])); ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
return SHAPELIST(CONSTANT(outShapeInfo)); return SHAPELIST(CONSTANT(outShapeInfo));
} }
for(int i = 1; i < numOfArrs; ++i) for(int i = 1; i < numOfArrs; ++i)
outShapeInfo[axis + 1] += nonEmptyArrShapes[i][axis + 1]; outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0])); ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
// delete dynamically allocated vectors shapes with length=1 // delete dynamically allocated vectors shapes with length=1
for(int index : shapesToDelete) for(int index : shapesToDelete)
RELEASE(nonEmptyArrShapes[index], block.getWorkspace()); RELEASE(arrShapes[index], block.getWorkspace());
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo)); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
RELEASE(outShapeInfo, block.getWorkspace()); RELEASE(outShapeInfo, block.getWorkspace());

View File

@ -32,6 +32,11 @@ namespace nd4j {
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
if(input->isEmpty()){
//No-op
return Status::OK();
}
const bool exclusive = INT_ARG(0) == 1; const bool exclusive = INT_ARG(0) == 1;
const bool reverse = INT_ARG(1) == 1; const bool reverse = INT_ARG(1) == 1;

View File

@ -36,6 +36,11 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) {
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
if(input->isEmpty()){
//No-op
return Status::OK();
}
if (block.getIArguments()->size() == 2 && block.width() == 1) { if (block.getIArguments()->size() == 2 && block.width() == 1) {
// all at once case // all at once case
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse); nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);

View File

@ -102,12 +102,6 @@ DECLARE_SHAPE_FN(gather) {
if(axis < 0) if(axis < 0)
axis += inputRank; axis += inputRank;
//Edge case: empty indices, empty input -> empty output
if(block.width() > 1 && INPUT_VARIABLE(0)->isEmpty() && INPUT_VARIABLE(1)->isEmpty()){
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
return SHAPELIST(empty);
}
REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank); REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
bool isEmpty = false; bool isEmpty = false;
@ -119,11 +113,6 @@ DECLARE_SHAPE_FN(gather) {
int outputRank = inputRank + indicesRank - 1; int outputRank = inputRank + indicesRank - 1;
if(INPUT_VARIABLE(1)->isEmpty()) { //Empty indices -> empty output
outputRank = 0;
isEmpty = true;
}
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);
// fill output shapeInfo // fill output shapeInfo

View File

@ -33,6 +33,11 @@ namespace ops {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
if(output->isEmpty()){
//No-op
return Status::OK();
}
std::vector<int> axis; std::vector<int> axis;
if (block.width() > 1) if (block.width() > 1)

View File

@ -204,6 +204,8 @@ namespace nd4j {
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r); mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescConv3d( static void getMKLDNNMemoryDescConv3d(
@ -212,56 +214,60 @@ namespace nd4j {
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r); mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescPool2d( static void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW, int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm, mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
static void getMKLDNNMemoryDescPool3d( static void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm, mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
#endif #endif
static void conv2d(nd4j::LaunchContext &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void conv2d(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs); static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
static void conv2dBP(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs); static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
static void conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void vol2col(nd4j::LaunchContext & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void col2vol(nd4j::LaunchContext & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); static void col2vol(nd4j::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW); static void upsampling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
static void upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW); static void upsampling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
static void upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW); static void upsampling2dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
static void upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW); static void upsampling3dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
static void pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0); static void pooling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
static void pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); static void pooling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0); static void pooling2dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); static void pooling3dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
}; };
} }

View File

@ -153,7 +153,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
if (inEWS == 1) { if (inEWS == 1) {
PRAGMA_OMP_SIMD_MAX(max) PRAGMA_OMP_SIMD_MAX(max)
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
max = nd4j::math::nd4j_max<T>(max, outBuff[i]); max = nd4j::math::nd4j_max<T>(max, inBuff[i]);
PRAGMA_OMP_SIMD_SUM(sum) PRAGMA_OMP_SIMD_SUM(sum)
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
@ -171,7 +171,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
PRAGMA_OMP_SIMD_MAX(max) PRAGMA_OMP_SIMD_MAX(max)
for (int i = 0; i < length; i++) for (int i = 0; i < length; i++)
max = nd4j::math::nd4j_max<T>(max, outBuff[i * inEWS]); max = nd4j::math::nd4j_max<T>(max, inBuff[i * inEWS]);
PRAGMA_OMP_SIMD_SUM(sum) PRAGMA_OMP_SIMD_SUM(sum)
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {

File diff suppressed because it is too large Load Diff

View File

@ -227,15 +227,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) { //NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
std::vector<Nd4jLong> shape = {bS, 4*numUnits}; std::vector<Nd4jLong> shape = {bS, 4*numUnits};
auto m = NDArrayFactory::create_('c', shape, xt->dataType(), nullptr); auto m = NDArrayFactory::create('c', shape, xt->dataType());
MmulHelper::mmul(&concatOut, W, m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
*m += (*b); //addiRowVector m += (*b); //addiRowVector
//Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
auto zi = (*m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits] auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
auto zz = (*m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits] auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
auto zf = (*m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits] auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
auto zo = (*m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits] auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
if(peephole) { // add peephole connections: z + ct_1*Wc if(peephole) { // add peephole connections: z + ct_1*Wc
zi += (*cLast) * (*Wci); // add peephole connections to input gate zi += (*cLast) * (*Wci); // add peephole connections to input gate

View File

@ -57,7 +57,7 @@ namespace helpers {
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
if (nullptr != indices) { if (nullptr != indices) {
// for max_pool_with_argmax // for max_pool_with_argmax

View File

@ -53,9 +53,16 @@ namespace nd4j {
dtype = nd4j::DataType::BOOL; dtype = nd4j::DataType::BOOL;
if(shape::isEmpty(x) || shape::isEmpty(y)) { if(shape::isEmpty(x) || shape::isEmpty(y)) {
//Edge case: broadcasting with empty array gives empty array output (behaviour to match TF for import cases) // this is edge case, [3, 4] + [] = []
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype); if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
shapeList->push_back(empty); shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
return shapeList;
}
Nd4jLong *newshape = nullptr;
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
} else if (shape::isScalar(x) && shape::isScalar(y)) { } else if (shape::isScalar(x) && shape::isScalar(y)) {
if (shape::rank(x) >= shape::rank(y)) { if (shape::rank(x) >= shape::rank(y)) {
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));

View File

@ -2873,7 +2873,7 @@ namespace simdOps {
const static functions::ReduceType reduceType = functions::ReduceType::MAX; const static functions::ReduceType reduceType = functions::ReduceType::MAX;
op_def static X startingValue(const X *input) { op_def static X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>(); return -nd4j::DataTypeUtils::infOrMax<X>();
} }
op_def static X merge(X old, X opOutput, X *extraParams) { op_def static X merge(X old, X opOutput, X *extraParams) {
@ -3051,7 +3051,7 @@ namespace simdOps {
const static functions::ReduceType reduceType = functions::ReduceType::MIN; const static functions::ReduceType reduceType = functions::ReduceType::MIN;
op_def static X startingValue(const X *input) { op_def static X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>(); return nd4j::DataTypeUtils::infOrMax<X>();
} }
op_def static X merge(X old, X opOutput, X *extraParams) { op_def static X merge(X old, X opOutput, X *extraParams) {
@ -3831,7 +3831,7 @@ namespace simdOps {
} }
static _CUDA_HD inline X startingValue(const X *input) { static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>(); return -nd4j::DataTypeUtils::infOrMax<X>();
} }
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) { static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3890,7 +3890,7 @@ namespace simdOps {
} }
static _CUDA_HD inline X startingValue(const X *input) { static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>(); return -nd4j::DataTypeUtils::infOrMax<X>();
} }
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) { static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3958,7 +3958,7 @@ namespace simdOps {
} }
static _CUDA_HD inline X startingValue(const X *input) { static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>(); return -nd4j::DataTypeUtils::infOrMax<X>();
} }
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) { static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3984,7 +3984,7 @@ namespace simdOps {
} }
static _CUDA_HD inline X startingValue(const X *input) { static _CUDA_HD inline X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>(); return nd4j::DataTypeUtils::infOrMax<X>();
} }
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) { static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -4040,7 +4040,7 @@ namespace simdOps {
} }
static _CUDA_HD inline X startingValue(const X *input) { static _CUDA_HD inline X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>(); return nd4j::DataTypeUtils::infOrMax<X>();
} }
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) { static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {

View File

@ -580,6 +580,152 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_1) {
ASSERT_TRUE(z.equalsTo(zExp)); ASSERT_TRUE(z.equalsTo(zExp));
} }
// Multiplies a [0, 4] array by a [1, 4] row, writing the result back into the
// first operand, and checks that the op succeeds and the output is still
// shaped [0, 4].
TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
    NDArray row('c', {1,4}, {1,2,3,4});
    NDArray inOut = NDArrayFactory::create<double>('c', {0, 4});
    NDArray expected = NDArrayFactory::create<double>('c', {0, 4});

    nd4j::ops::multiply multiplyOp;

    // inOut is passed both as the first input and as the output array
    auto execStatus = multiplyOp.execute({&inOut, &row}, {&inOut}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, execStatus);

    ASSERT_TRUE(expected.isSameShape(inOut));
    ASSERT_TRUE(expected.equalsTo(inOut));
}
// maximum() of a [1, 0, 2] array and a rank-0 array holding 0.1: the expected
// output keeps the [1, 0, 2] shape of the first operand.
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
    NDArray lhs = NDArrayFactory::create<float>('c', {1, 0, 2});
    NDArray rhs('c', {}, {0.1}, nd4j::DataType::FLOAT32);
    NDArray expected = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::maximum maximumOp;
    auto results = maximumOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    // first (and only checked) output of the op
    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(*output));

    // the result set is heap-allocated by execute() and released by the test
    delete results;
}
// maximum() of a [1, 0, 1] array and a [1, 0, 2] array: the expected output
// is shaped [1, 0, 2].
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
    NDArray lhs = NDArrayFactory::create<float>('c', {1, 0, 1});
    NDArray rhs = NDArrayFactory::create<float>('c', {1, 0, 2});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::maximum maximumOp;
    auto results = maximumOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    // first (and only checked) output of the op
    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(*output));

    // the result set is heap-allocated by execute() and released by the test
    delete results;
}
// realdiv() of a [1, 0, 1] array by a [1, 0, 2] array: the expected output
// is shaped [1, 0, 2].
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
    NDArray numerator = NDArrayFactory::create<float>('c', {1, 0, 1});
    NDArray denominator = NDArrayFactory::create<float>('c', {1, 0, 2});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::realdiv realdivOp;
    auto results = realdivOp.execute({&numerator, &denominator}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    // first (and only checked) output of the op
    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(*output));

    // the result set is heap-allocated by execute() and released by the test
    delete results;
}
// realdiv() of a [1, 0, 1] array by a [1, 2] array filled with {2, 2}: the
// expected output is shaped [1, 0, 2].
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
    NDArray numerator = NDArrayFactory::create<float>('c', {1, 0, 1});
    NDArray denominator = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 0, 2});

    nd4j::ops::realdiv realdivOp;
    auto results = realdivOp.execute({&numerator, &denominator}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    // first (and only checked) output of the op
    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(*output));

    // the result set is heap-allocated by execute() and released by the test
    delete results;
}
// realdiv() of a [1, 0, 2, 1] array by a [1, 2, 0] array (operands of
// different rank): the expected output is shaped [1, 0, 2, 0].
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
    NDArray numerator = NDArrayFactory::create<float>('c', {1, 0, 2, 1});
    NDArray denominator = NDArrayFactory::create<float>('c', {1, 2, 0});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 0, 2, 0});

    nd4j::ops::realdiv realdivOp;
    auto results = realdivOp.execute({&numerator, &denominator}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    // first (and only checked) output of the op
    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(*output));

    // the result set is heap-allocated by execute() and released by the test
    delete results;
}
// Runs greater() on a DOUBLE array built with the (dtype, context, false)
// constructor and a dense [3, 4] array, writing into a caller-supplied BOOL
// output, then checks the output against an identically constructed BOOL array.
// NOTE(review): the (dtype, context, false) constructor presumably produces an
// empty array - confirm against the NDArray declaration.
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) {
    NDArray dense('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
    NDArray lhs(nd4j::DataType::DOUBLE, dense.getContext(), false);
    NDArray actual(nd4j::DataType::BOOL, dense.getContext(), false);
    NDArray expected(nd4j::DataType::BOOL, dense.getContext(), false);

    nd4j::ops::greater greaterOp;

    // output array is provided explicitly instead of being allocated by the op
    auto execStatus = greaterOp.execute({&lhs, &dense}, {&actual}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, execStatus);

    ASSERT_TRUE(actual.isSameShape(expected));
    ASSERT_TRUE(actual.equalsTo(expected));
}
// greater() of a DOUBLE [0, 4] array and a [1, 4] row: the expected output is
// a BOOL array shaped [0, 4].
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
    NDArray y('c', {1,4}, {1,2,3,4});
    NDArray x = NDArrayFactory::create<double>('c', {0, 4});
    NDArray e = NDArrayFactory::create<bool>('c', {0, 4});

    nd4j::ops::greater op;
    auto result = op.execute({&x, &y}, {}, {}, {});

    // Check the status before touching any output, as the sibling tests do:
    // ASSERT_EQ is fatal, so a failed op now ends the test cleanly instead of
    // dereferencing whatever at(0) yields for a failed execution.
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    // diagnostic dump of the produced shape (kept from the original test)
    z->printShapeInfo("z");

    ASSERT_TRUE(e.isSameShape(z));
    ASSERT_TRUE(e.equalsTo(*z));

    // the result set is heap-allocated by execute() and released by the test
    delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_bool_1) { TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32); NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32);

View File

@ -2021,7 +2021,8 @@ TEST_F(ConvolutionTests1, vol2col_test1) {
// PointersManager manager(columnsExpected.getContext()); // PointersManager manager(columnsExpected.getContext());
// manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf()); // manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf());
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); graph::Context context(1);
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
ASSERT_TRUE(columns.equalsTo(columnsExpected)); ASSERT_TRUE(columns.equalsTo(columnsExpected));
} }
@ -2052,7 +2053,8 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.}); -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.});
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); graph::Context context(1);
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
ASSERT_TRUE(columns.equalsTo(columnsExpected)); ASSERT_TRUE(columns.equalsTo(columnsExpected));
} }

View File

@ -1302,8 +1302,8 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
TEST_F(DeclarableOpsTests10, broadcast_to_test7) { TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
auto input = NDArrayFactory::create<double>(10.f); auto input = NDArrayFactory::create<double>(10.f);
auto shape = NDArrayFactory::create<double>(0.f); auto shape = NDArrayFactory::create<Nd4jLong>(1);
auto exp = NDArrayFactory::create<double>(10.f); auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
nd4j::ops::broadcast_to op; nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {}); auto results = op.execute({&input, &shape}, {}, {}, {});
@ -2261,8 +2261,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
NDArray min('c', {0}, {-63.65f}, nd4j::DataType::FLOAT32); NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {0}, {0.1f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
nd4j::ops::fake_quant_with_min_max_vars op; nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {}); auto results = op.execute({&x, &min, &max}, {}, {});

View File

@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
NDArray dLdwExp('c', {0}, {-227.77286}); NDArray dLdwExp('c', {}, {-227.77286});
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.}); NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -583,7 +583,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
NDArray dLdwExp('c', {0}, {4515.84}); NDArray dLdwExp('c', {}, {4515.84});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -702,7 +702,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.}); NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
NDArray dLdwExp('c', {0}, {288.}); NDArray dLdwExp('c', {}, {288.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.}); NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04); predictions.linspace(0.04, 0.04);
labels.linspace(1); labels.linspace(1);
@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
NDArray dLdwExp('c', {0}, {-91.52109}); NDArray dLdwExp('c', {}, {-91.52109});
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.}); NDArray dLdwExp('c', {}, {0.});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
labels.linspace(1); labels.linspace(1);
@ -2001,10 +2001,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE); NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
NDArray dLdwExp('c', {0}, {1.38629}); NDArray dLdwExp('c', {}, {1.38629});
logits = 2.; logits = 2.;
weights.assign(0.5); weights.assign(0.5);
@ -2032,10 +2032,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE); NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
NDArray dLdwExp('c', {0}, {0.}); NDArray dLdwExp('c', {}, {0.});
logits.linspace(-0.08, 0.04); logits.linspace(-0.08, 0.04);
weights = 0.5; weights = 0.5;
@ -2466,7 +2466,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
///////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
NDArray labels('c', {0}, {1}, nd4j::DataType::INT64); NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
NDArray logits('c', {2}, {-0.2, 0.3}); NDArray logits('c', {2}, {-0.2, 0.3});
NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); NDArray dLdpExp('c', {2}, {0.37754, -0.37754});

View File

@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE); NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE); NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
NDArray dLdwExp('c', {0}, {1.3}); NDArray dLdwExp('c', {}, {1.3});
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
predictions.linspace(-0.4, 0.2); predictions.linspace(-0.4, 0.2);
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
TEST_F(DeclarableOpsTests12, hinge_loss_14) { TEST_F(DeclarableOpsTests12, hinge_loss_14) {
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE); NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, {1.}); NDArray weights('c', {}, {1.});
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
NDArray output('c', {0}, nd4j::DataType::DOUBLE); NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE);
logits.linspace(1.); logits.linspace(1.);
weights.assign(1.); weights.assign(1.);
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
TEST_F(DeclarableOpsTests12, reverse_test15) { TEST_F(DeclarableOpsTests12, reverse_test15) {
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE); NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
NDArray axis('c', {0}, {0}, nd4j::DataType::INT32); NDArray axis('c', {}, {0}, nd4j::DataType::INT32);
NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray z('c', {5}, nd4j::DataType::DOUBLE);
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);

View File

@ -123,6 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
ASSERT_EQ(Status::OK(), result->status()); ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0); auto z = result->at(0);
z->printIndexedBuffer("Reduced shape");
ASSERT_EQ(e, *z); ASSERT_EQ(e, *z);
delete result; delete result;
@ -213,3 +214,199 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
delete result; delete result;
} }
// Runs the fill op with an empty int array as first input and an array built
// from the single value 1 as second input; the output must equal the latter.
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
    auto shapeInput = NDArrayFactory::empty<int>();
    auto valueInput = NDArrayFactory::create<int>(1);

    nd4j::ops::fill fillOp;
    auto outcome = fillOp.execute({&shapeInput, &valueInput}, {}, {});
    ASSERT_EQ(Status::OK(), outcome->status());

    // first (and only inspected) output of the op
    auto produced = outcome->at(0);
    ASSERT_EQ(valueInput, *produced);

    delete outcome;
}
// Smoke test for lstmBlockCell: executes the op with eight inputs and seven
// pre-allocated outputs and only verifies the returned status; no output
// values are checked.
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
// Inputs. Only `a` ({1, 5}) and `d` ({8, 12}) carry explicit values; the
// remaining arrays are created from a shape alone.
// NOTE(review): the role of each input (e.g. input / previous state / weights /
// peepholes / bias) is not visible here - confirm against the op declaration.
auto a = NDArrayFactory::create<float>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
auto b = NDArrayFactory::create<float>('c', {1, 3});
auto c = NDArrayFactory::create<float>('c', {1, 3});
auto d = NDArrayFactory::create<float>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
auto e = NDArrayFactory::create<float>('c', {3});
auto f = NDArrayFactory::create<float>('c', {3});
auto g = NDArrayFactory::create<float>('c', {3});
auto h = NDArrayFactory::create<float>('c', {12});
// Seven caller-owned output arrays, all of shape {1, 3}.
auto z0 = NDArrayFactory::create<float>('c', {1, 3});
auto z1 = NDArrayFactory::create<float>('c', {1, 3});
auto z2 = NDArrayFactory::create<float>('c', {1, 3});
auto z3 = NDArrayFactory::create<float>('c', {1, 3});
auto z4 = NDArrayFactory::create<float>('c', {1, 3});
auto z5 = NDArrayFactory::create<float>('c', {1, 3});
auto z6 = NDArrayFactory::create<float>('c', {1, 3});
nd4j::ops::lstmBlockCell op;
// This execute overload receives the output arrays explicitly and its return
// value is compared directly with Status::OK(), so there is nothing to delete.
// NOTE(review): the meaning of the {1.0, -1.0} and {0} argument lists is not
// shown here - confirm against the op declaration.
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
ASSERT_EQ(Status::OK(), result);
}
// Two checks on empty arrays:
//  1. stacking a single array of shape {0} must produce an array equal to one
//     created with shape {1, 0};
//  2. reduce_min applied to that {1, 0} array must yield
//     DataTypeUtils::infOrMax<float>() at index 0.
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
    auto x = NDArrayFactory::create<float>('c', {0});
    auto e = NDArrayFactory::create<float>('c', {1, 0});

    nd4j::ops::stack op;
    auto result = op.execute({&x}, {}, {0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    ASSERT_EQ(e, *z);

    // The op under test here is reduce_min, so the variable and the debug labels
    // are named accordingly (they previously said "sum"/"ReduceSum").
    // NOTE(review): T-arg 1.0 presumably enables keepDims and I-arg 1 selects the
    // reduced axis - confirm against the reduce_min declaration.
    nd4j::ops::reduce_min minOp;
    auto res2 = minOp.execute({&e}, {1.}, {1});
    ASSERT_EQ(res2->status(), Status::OK());

    auto out = res2->at(0);
    out->printShapeInfo("ReduceMin empty shape with keep dims");
    out->printIndexedBuffer("ReduceMin scalar");
    ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());

    delete res2;
    delete result;
}
// Stacking a single NDArrayFactory::empty<float>() array (integer argument 0)
// is expected to equal an array created with shape {0}.
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
    auto emptyInput = NDArrayFactory::empty<float>();
    auto expected = NDArrayFactory::create<float>('c', {0});

    nd4j::ops::stack stackOp;
    auto outcome = stackOp.execute({&emptyInput}, {}, {0});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto stacked = outcome->at(0);
    ASSERT_EQ(expected, *stacked);

    delete outcome;
}
// Stacking the same NDArrayFactory::empty<float>() array twice (integer
// argument 0) is expected to equal an array created with shape {2, 0}.
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
    auto emptyInput = NDArrayFactory::empty<float>();
    auto expected = NDArrayFactory::create<float>('c', {2, 0});

    nd4j::ops::stack stackOp;
    auto outcome = stackOp.execute({&emptyInput, &emptyInput}, {}, {0});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto stacked = outcome->at(0);
    ASSERT_EQ(expected, *stacked);

    delete outcome;
}
// Stacking the same shape-{0} array twice (integer argument 0) is expected to
// equal an array created with shape {2, 0}.
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
    auto zeroLength = NDArrayFactory::create<float>('c', {0});
    auto expected = NDArrayFactory::create<float>('c', {2, 0});

    nd4j::ops::stack stackOp;
    auto outcome = stackOp.execute({&zeroLength, &zeroLength}, {}, {0});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto stacked = outcome->at(0);
    ASSERT_EQ(expected, *stacked);

    delete outcome;
}
// reduce_min on an array of shape {1, 0}: element 0 of the output must equal
// DataTypeUtils::infOrMax<float>().
// NOTE(review): T-arg 1.0 presumably enables keepDims and I-arg 1 selects the
// reduced axis - confirm against the op declaration.
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
    auto emptyInput = NDArrayFactory::create<float>('c', {1, 0});

    nd4j::ops::reduce_min minOp;
    auto outcome = minOp.execute({&emptyInput}, {1.}, {1});
    ASSERT_EQ(outcome->status(), Status::OK());

    auto reduced = outcome->at(0);
    ASSERT_EQ(reduced->e<float>(0), DataTypeUtils::infOrMax<float>());

    delete outcome;
}
// reduce_max on an array of shape {1, 0}: element 0 of the output must equal
// the negated DataTypeUtils::infOrMax<float>().
// NOTE(review): T-arg 1.0 presumably enables keepDims and I-arg 1 selects the
// reduced axis - confirm against the op declaration.
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
    auto emptyInput = NDArrayFactory::create<float>('c', {1, 0});

    nd4j::ops::reduce_max maxOp;
    auto outcome = maxOp.execute({&emptyInput}, {1.}, {1});
    ASSERT_EQ(outcome->status(), Status::OK());

    auto reduced = outcome->at(0);
    ASSERT_EQ(reduced->e<float>(0), -DataTypeUtils::infOrMax<float>());

    delete outcome;
}
// reduce_sum on an array of shape {1, 0}: element 0 of the output must be 0.
// NOTE(review): T-arg 1.0 presumably enables keepDims and I-arg 1 selects the
// reduced axis - confirm against the op declaration.
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
    auto emptyInput = NDArrayFactory::create<float>('c', {1, 0});

    nd4j::ops::reduce_sum reduceOp;
    auto outcome = reduceOp.execute({&emptyInput}, {1.}, {1});
    ASSERT_EQ(outcome->status(), Status::OK());

    auto reduced = outcome->at(0);
    ASSERT_EQ(reduced->e<float>(0), 0.f);

    delete outcome;
}
// reduce_mean on an array of shape {1, 0}: prints the output's shape and
// buffer for inspection, then requires element 0 of the output to be 0.
// NOTE(review): T-arg 1.0 presumably enables keepDims and I-arg 1 selects the
// reduced axis - confirm against the op declaration.
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
    auto emptyInput = NDArrayFactory::create<float>('c', {1, 0});

    nd4j::ops::reduce_mean meanOp;
    auto outcome = meanOp.execute({&emptyInput}, {1.}, {1});
    ASSERT_EQ(outcome->status(), Status::OK());

    auto reduced = outcome->at(0);
    reduced->printShapeInfo("ReduceMean empty shape with keep dims");
    reduced->printIndexedBuffer("ReduceMean scalar");
    ASSERT_EQ(reduced->e<float>(0), 0.f);

    delete outcome;
}
// argmax with a {1, 0} float array and an int array holding 0 as inputs: the
// output must equal an Nd4jLong array created with shape {0}.
// NOTE(review): the second input is presumably the axis to reduce - confirm.
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
    auto emptyInput = NDArrayFactory::create<float>('c', {1, 0});
    auto axisInput = NDArrayFactory::create<int>(0);
    auto expected = NDArrayFactory::create<Nd4jLong>('c', {0});

    nd4j::ops::argmax argmaxOp;
    auto outcome = argmaxOp.execute({&emptyInput, &axisInput}, {}, {});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto indices = outcome->at(0);
    indices->printShapeInfo("Z");
    ASSERT_EQ(expected, *indices);

    delete outcome;
}
// Negative test: argmax with a {1, 0} float array and an int array holding 1 as
// inputs (the int array also being passed as the output) is required to throw
// a std::exception. The test fails if execute() returns normally.
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
    auto x = NDArrayFactory::create<float>('c', {1, 0});
    auto y = NDArrayFactory::create<int>(1);

    nd4j::ops::argmax op;
    try {
        // The returned status is irrelevant: reaching the next line at all is the failure.
        op.execute({&x, &y}, {&y}, {}, {}, {});
        ASSERT_TRUE(false) << "argmax was expected to throw for this empty input, but execute() returned";
    } catch (const std::exception &) {
        // expected path: the op rejected the input
    }
}
// tanh on an array of shape {32, 0}: the output must have the same shape as,
// and compare equal to, the input.
TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
    auto emptyInput = NDArrayFactory::create<float>('c', {32, 0});

    nd4j::ops::tanh tanhOp;
    auto outcome = tanhOp.execute({&emptyInput}, {}, {});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto activated = outcome->at(0);
    ASSERT_TRUE(emptyInput.isSameShape(activated));
    ASSERT_EQ(emptyInput, *activated);

    delete outcome;
}

View File

@ -905,9 +905,9 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
auto x = NDArrayFactory::create<double>('c', {1}, {10}); auto x = NDArrayFactory::create<double>('c', {1}, {10});
auto begin = NDArrayFactory::create<double>('c', {1}, {0.}); auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0});
auto end = NDArrayFactory::create<double>('c', {1}, {0.}); auto end = NDArrayFactory::create<int>('c', {1}, {(int)0});
auto stride = NDArrayFactory::create<double>('c', {1}, {1}); auto stride = NDArrayFactory::create<int>('c', {1}, {1});
//x.linspace(1); //x.linspace(1);
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5}); //auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
//exp.linspace(1); //exp.linspace(1);

View File

@ -2421,6 +2421,26 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) {
delete results; delete results;
} }
//////////////////////////////////////////////////////////////////////
// Runs log_softmax ten times on the same {1, 4} input, each time with a freshly
// constructed op, and requires every run to match the expected output within
// an absolute tolerance of 1e-4.
TEST_F(DeclarableOpsTests5, log_softmax_test12) {
    auto input = NDArrayFactory::create<double>('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864});
    auto expOutput = NDArrayFactory::create<double>('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472});

    for (int i = 0; i < 10; ++i) {
        nd4j::ops::log_softmax op;
        auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);

        // Verify the execution status before touching any output, so a failed run
        // is reported as a status mismatch rather than by accessing a missing result.
        ASSERT_EQ(Status::OK(), results->status());
        auto z = results->at(0);

        ASSERT_TRUE(expOutput.isSameShape(z));
        ASSERT_TRUE(expOutput.equalsTo(z, 1e-4));

        delete results;
    }
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {

View File

@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto e = NDArrayFactory::create<double>('c', {1}, {0.}); auto e = NDArrayFactory::create<double>('c', {1}, {0.});
auto s = NDArrayFactory::create<double>('c', {1}, {1.0}); auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f}); auto exp = NDArrayFactory::create<double>(10);
//matrix.linspace(1); //matrix.linspace(1);
@ -119,7 +119,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto z = result->at(0); auto z = result->at(0);
z->printShapeInfo("SS OS shape"); z->printShapeInfo("SS OS shape");
ASSERT_TRUE(z->isEmpty()); z->printIndexedBuffer("SS OS out");
ASSERT_TRUE(z->equalsTo(exp));
//ASSERT_EQ(exp, *z); //ASSERT_EQ(exp, *z);
delete result; delete result;

View File

@ -608,6 +608,24 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
delete result; delete result;
} }
//////////////////////////////////////////////////////////////////////
// Concatenating two arrays of shape {0, 2, 3} (integer argument 0) must give an
// output with shape {0, 2, 3}; only the shape is checked.
TEST_F(DeclarableOpsTests9, concat_test16) {
    auto first = NDArrayFactory::create<double>('c', {0,2,3});
    auto second = NDArrayFactory::create<double>('c', {0,2,3});
    auto expected = NDArrayFactory::create<double>('c', {0,2,3});

    nd4j::ops::concat concatOp;
    auto outcome = concatOp.execute({&first, &second}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto joined = outcome->at(0);
    ASSERT_TRUE(expected.isSameShape(joined));

    delete outcome;
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test3) { TEST_F(DeclarableOpsTests9, tile_bp_test3) {

View File

@ -59,7 +59,8 @@ TEST_F(EmptyTests, Test_Create_Empty_2) {
} }
TEST_F(EmptyTests, Test_Concat_1) { TEST_F(EmptyTests, Test_Concat_1) {
auto empty = NDArrayFactory::empty_<float>(); // auto empty = NDArrayFactory::empty_<float>();
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32);//NDArrayFactory::create_<float>('c', {(Nd4jLong)0}};
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f}); auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
ASSERT_TRUE(empty->isEmpty()); ASSERT_TRUE(empty->isEmpty());
@ -82,9 +83,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
TEST_F(EmptyTests, Test_Concat_2) { TEST_F(EmptyTests, Test_Concat_2) {
auto empty = NDArrayFactory::empty_<float>(); auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32); //NDArrayFactory::empty_<float>();
auto scalar1 = NDArrayFactory::create_<float>(1.0f); auto scalar1 = NDArrayFactory::create_<float>('c', {1}, {1.0f});
auto scalar2 = NDArrayFactory::create_<float>(2.0f); auto scalar2 = NDArrayFactory::create_<float>('c', {1}, {2.0f});
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f}); auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
ASSERT_TRUE(empty->isEmpty()); ASSERT_TRUE(empty->isEmpty());
@ -139,6 +140,23 @@ TEST_F(EmptyTests, Test_Reshape_2) {
delete result; delete result;
} }
// Reshapes an array of shape {1, 0, 0, 2} using a second input holding
// {10, 0}; the output must match an array created with shape {10, 0} in both
// shape and equality.
TEST_F(EmptyTests, Test_Reshape_3) {
    auto source = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
    auto targetShape = NDArrayFactory::create<int>('c', {2}, {10, 0});
    auto expected = NDArrayFactory::create<float>('c', {10, 0});

    nd4j::ops::reshape reshapeOp;
    auto outcome = reshapeOp.execute({&source, &targetShape}, {}, {});
    ASSERT_EQ(Status::OK(), outcome->status());

    auto reshaped = outcome->at(0);
    ASSERT_TRUE(expected.isSameShape(reshaped));
    ASSERT_EQ(expected, *reshaped);

    delete outcome;
}
TEST_F(EmptyTests, Test_dup_1) { TEST_F(EmptyTests, Test_dup_1) {
auto empty = NDArrayFactory::empty<int>(); auto empty = NDArrayFactory::empty<int>();
auto dup = empty.dup(); auto dup = empty.dup();
@ -148,3 +166,47 @@ TEST_F(EmptyTests, Test_dup_1) {
delete dup; delete dup;
} }
// An array created with shape {2, 0, 3} must report FLOAT32, length 0,
// isEmpty(), the same shape vector and rank 3.
TEST_F(EmptyTests, test_shaped_empty_1) {
    auto array = NDArrayFactory::create<float>('c', {2, 0, 3});
    std::vector<Nd4jLong> expectedShape = {2, 0, 3};

    ASSERT_EQ(nd4j::DataType::FLOAT32, array.dataType());
    ASSERT_EQ(0, array.lengthOf());
    ASSERT_TRUE(array.isEmpty());
    ASSERT_EQ(expectedShape, array.getShapeAsVector());
    ASSERT_EQ(3, array.rankOf());
}
// An array created with shape {0, 3} must report FLOAT32, length 0,
// isEmpty(), the same shape vector and rank 2.
TEST_F(EmptyTests, test_shaped_empty_2) {
    auto array = NDArrayFactory::create<float>('c', {0, 3});
    std::vector<Nd4jLong> expectedShape = {0, 3};

    ASSERT_EQ(nd4j::DataType::FLOAT32, array.dataType());
    ASSERT_EQ(0, array.lengthOf());
    ASSERT_TRUE(array.isEmpty());
    ASSERT_EQ(expectedShape, array.getShapeAsVector());
    ASSERT_EQ(2, array.rankOf());
}
// An array created with shape {0} must report FLOAT32, length 0, isEmpty(),
// the same shape vector and rank 1.
TEST_F(EmptyTests, test_shaped_empty_3) {
    auto array = NDArrayFactory::create<float>('c', {0});
    std::vector<Nd4jLong> expectedShape = {0};

    ASSERT_EQ(nd4j::DataType::FLOAT32, array.dataType());
    ASSERT_EQ(0, array.lengthOf());
    ASSERT_TRUE(array.isEmpty());
    ASSERT_EQ(expectedShape, array.getShapeAsVector());
    ASSERT_EQ(1, array.rankOf());
}
// Builds an NDArray directly from ConstantShapeHelper's vectorShapeInfo(0, FLOAT32)
// (printing that shape info first) and requires the array to be empty, of
// rank 1 and with shape vector {0}.
TEST_F(EmptyTests, test_shaped_empty_4) {
    auto shapeInfo = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::FLOAT32);
    shape::printShapeInfoLinear("shape", shapeInfo);

    NDArray fromShapeInfo(shapeInfo, true, nd4j::LaunchContext::defaultContext());
    std::vector<Nd4jLong> expectedShape({0});

    ASSERT_TRUE(fromShapeInfo.isEmpty());
    ASSERT_EQ(1, fromShapeInfo.rankOf());
    ASSERT_EQ(expectedShape, fromShapeInfo.getShapeAsVector());
}

View File

@ -668,3 +668,47 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
delete row; delete row;
delete erow; delete erow;
} }
// Calls NativeOpExecutioner::execReduceSame with SameOps::Sum on a {2, 0, 3}
// input along dimension 1, writing into a {2, 3} output, which must equal a
// freshly created {2, 3} array.
// NOTE(review): freshly created arrays are presumably zero-filled - confirm.
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) {
    auto input = NDArrayFactory::create<float>('c', {2, 0, 3});
    auto output = NDArrayFactory::create<float>('c', {2, 3});
    auto expected = NDArrayFactory::create<float>('c', {2, 3});

    int axis = 1;

    NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(),
                                        reduce::SameOps::Sum,
                                        input.buffer(), input.shapeInfo(),
                                        input.specialBuffer(), input.specialShapeInfo(),
                                        nullptr,
                                        output.buffer(), output.shapeInfo(),
                                        output.specialBuffer(), output.specialShapeInfo(),
                                        &axis, 1,
                                        input.shapeInfo(), nullptr);

    ASSERT_EQ(expected, output);
}
// Calls NativeOpExecutioner::execReduceSame with SameOps::Min on a {2, 0, 3}
// input along dimension 1, writing into a {2, 3} output, which must equal a
// {2, 3} array filled with +infinity.
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) {
    auto input = NDArrayFactory::create<float>('c', {2, 0, 3});
    auto output = NDArrayFactory::create<float>('c', {2, 3});
    auto expected = NDArrayFactory::create<float>('c', {2, 3});
    expected.assign(std::numeric_limits<float>::infinity());

    int axis = 1;

    NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(),
                                        reduce::SameOps::Min,
                                        input.buffer(), input.shapeInfo(),
                                        input.specialBuffer(), input.specialShapeInfo(),
                                        nullptr,
                                        output.buffer(), output.shapeInfo(),
                                        output.specialBuffer(), output.specialShapeInfo(),
                                        &axis, 1,
                                        input.shapeInfo(), nullptr);

    ASSERT_EQ(expected, output);
}
// Calls NativeOpExecutioner::execReduceSame with SameOps::Max on a {2, 0, 3}
// input along dimension 1, writing into a {2, 3} output, which must equal a
// {2, 3} array filled with -infinity.
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) {
    auto input = NDArrayFactory::create<float>('c', {2, 0, 3});
    auto output = NDArrayFactory::create<float>('c', {2, 3});
    auto expected = NDArrayFactory::create<float>('c', {2, 3});
    expected.assign(-std::numeric_limits<float>::infinity());

    int axis = 1;

    NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(),
                                        reduce::SameOps::Max,
                                        input.buffer(), input.shapeInfo(),
                                        input.specialBuffer(), input.specialShapeInfo(),
                                        nullptr,
                                        output.buffer(), output.shapeInfo(),
                                        output.specialBuffer(), output.specialShapeInfo(),
                                        &axis, 1,
                                        input.shapeInfo(), nullptr);

    ASSERT_EQ(expected, output);
}
// Calls NativeOpExecutioner::execTransformFloat with FloatOps::RSqrt on a
// {1, 0, 4} array, using the same array as input and output. There are no
// assertions: the test only requires the call to complete.
TEST_F(LegacyOpsTests, test_legacy_transform_float_1) {
    auto array = NDArrayFactory::create<float>('c', {1, 0, 4});

    NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(),
                                            transform::FloatOps::RSqrt,
                                            array.buffer(), array.shapeInfo(),
                                            array.specialBuffer(), array.specialShapeInfo(),
                                            array.buffer(), array.shapeInfo(),
                                            array.specialBuffer(), array.specialShapeInfo(),
                                            nullptr, nullptr, nullptr);
}

Some files were not shown because too many files have changed in this diff Show More