Dev branch merge: dev_20190606 (#7904)

* correct logsoftmax looss (#2)

* Small SameDiff listener fix (#4)

* Various fixes (#6)

* #7839 Fix for asXMatrix and tests

* #7866 EmbeddingSequenceLayer dtype fix + test

* #7856 SameDiff save/load stream methods

* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration

* EvaluationBinary 3d/4d

* More evaluation 3d/4d tests

* #7847 Evaluation empty checks

* Small test ifx

* #7848 Fix median edge case

* Improve DL4J samediff layer tests

* [WIP] FastText wrapper implemented (#8)

* FastText implemented

* Some fixes

* Fix shapes for wordsNearest

* Validation of input vectors

* Fixes

* Fixed test

* Thread tagged

* Some tweaks

* setContextClassLoader for DeallocatorServiceThread

* Numpy format tests (#1)

* Various fixes (#11)

* #7852 SameDiff gather fix

* #7892 SameDiff placeholder to constant conversion

* #7890 validate input rank for MLN/CG init methods

* Fix broken permute shape calculation

* Permute and gather fixes

* Tests

* #7850 LogSumExp fix + test

* Handful of test fixes

* Empty arrays with non-scalar shapes (#10)

* minor rearrangements for lambdas

* empty tensors with non-scalar shapes

* numpy empty tensors with non-scalar shapes

* few more empty tweaks

* Small fixes

* conv3d signature update

* micro fix in batchnorm mkldnn

* Import fixes

* Fix

* MKL-DNN update

* Small fill fix

* fill with empty input + test

* Fixes

* Small error improvement

* Fix

* one special test

* couple of fixes for lstm

* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone

* Fixes

* FP16

* Unsigned

* BFloat16

* Fill op - empty tweaks

* - couple of fixes for empty arrays construction
- stack updated

* strided slice fix

* one transform test

* provide method for reducing shapeInfo in case of input array is empty

* Fixed reduceAlongDimensions to use empty input properly.

* couple of broadcast tests

* couple of tests broadcast tests + tweak to make them pass

* add check of non-empty to methods producing sub-arrays

* Fixed reshapeC with zeros in shape.

* complete empty check in reduce_... legacy ops

* Concat and cumsum/prod

* Tweak to empty shape inference on import

* add empty check to the rest of reduce legacy ops

* one more test

* correct typo in evalReduceShapeInfoEmpty

* Added tests for reduce_* ops to tests with zero shapes.

* few more tests for empty reductions

* Fixed strided_slice op with empty case and tests.

* one more empty reduction test

* Fixed strided_slice test.

* add empty check to NDArray::reshapei

* infOrMax

* empty min/max with infinity tests

* made unstack working correctly with empty arrays

* few IndexReduce tests + tweaks for empty shapes

* add test for empty concat

* few tests fixed

* Validation fix for reductions on empty shapes

* Reverse fix

* Reduction shape calc fixes

* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs

* Range fix

* - NDArray constructor updated for scalars/empty arrays
- few tests fixed

* More fixes

* Empty creator fixes

* concat fix

* concat fix

* TF import tests: allow 'both all NaN' and 'both all inf' to pass

* Slice, zero fraction, and reshape fixes

* transpose, gather

* Zero fraction

* scalar cast fix

* Empty reduction axis support

* few more tests fixed

* Fixed input checks conforming with TF for concat op and tests.

* few tests fixed

* matmul scalar shape fix

* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.

* broadcast bool fix

* few more tests

* few more tests

* correct evalReduceShapeInfoEmpty

* argmax/argmin + tests

* one more empty edge case + one more test

* argmax/argmin/realdiv_bp tweaks

* empty reshape test + fix

* Helper fixes

* Small fixes

* Gather test fix

* Gather test fix

* Small fixes

* reduce scalar zero values

* scalar mean workaround

* Remove debug code

* along dim mean workaround

* one more test

* - equalsTo() tweak for empty arrays
- one more test

* broadcast tweaks
master
Alex Black 2019-06-15 21:34:34 +10:00 committed by GitHub
parent 32e5cc1945
commit 68ea5f3688
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
169 changed files with 7207 additions and 3633 deletions

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
@ -283,7 +284,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
net.computeGradientAndScore();
net2.computeGradientAndScore();
System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-6);
Map<String, INDArray> gradient = net.gradient().gradientForVariable();
@ -441,85 +441,87 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
int numInputClasses = 10;
int timeSeriesLength = 5;
for (int nExamples : miniBatchSizes) {
Nd4j.getRandom().setSeed(12345);
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
for (int nExamples : miniBatchSizes) {
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
.nOut(5).build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
.nOut(5).build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
.build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
.build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
net2.setParams(net.params().dup());
net2.setParams(net.params().dup());
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength);
INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int inIdx = r.nextInt(numInputClasses);
inEmbedding.putScalar(new int[]{i, 0, j}, inIdx);
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int inIdx = r.nextInt(numInputClasses);
inEmbedding.putScalar(new int[]{i, 0, j}, inIdx);
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
int outIdx = r.nextInt(4);
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
int outIdx = r.nextInt(4);
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
}
}
}
INDArray inputMask = Nd4j.zeros(nExamples, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
}
}
}
net.setLayerMaskArrays(inputMask, null);
net2.setLayerMaskArrays(inputMask, null);
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
List<INDArray> actDense = net2.feedForward(inDense, false);
for (int i = 1; i < actEmbedding.size(); i++) {
assertEquals(actDense.get(i), actEmbedding.get(i));
}
net.setLayerMaskArrays(inputMask, null);
net2.setLayerMaskArrays(inputMask, null);
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
List<INDArray> actDense = net2.feedForward(inDense, false);
for (int i = 1; i < actEmbedding.size(); i++) {
assertEquals(actDense.get(i), actEmbedding.get(i));
}
net.setLabels(labels);
net2.setLabels(labels);
net.computeGradientAndScore();
net2.computeGradientAndScore();
net.setLabels(labels);
net2.setLabels(labels);
net.computeGradientAndScore();
net2.computeGradientAndScore();
System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-5);
System.out.println(net.score() + "\t" + net2.score());
assertEquals(net2.score(), net.score(), 1e-5);
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
assertEquals(gradients.keySet(), gradients2.keySet());
for (String s : gradients.keySet()) {
assertEquals(gradients2.get(s), gradients.get(s));
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
assertEquals(gradients.keySet(), gradients2.keySet());
for (String s : gradients.keySet()) {
assertEquals(gradients2.get(s), gradients.get(s));
}
}
}
}
@ -583,6 +585,104 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
}
}
@Test
public void testEmbeddingSequenceLayerWithMasking() {
//Idea: have masking on the input with an embedding and dense layers on input
//Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked
int[] miniBatchSizes = {1, 3};
int nIn = 2;
Random r = new Random(12345);
int numInputClasses = 10;
int timeSeriesLength = 5;
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
for(int inputRank : new int[]{2, 3}) {
for (int nExamples : miniBatchSizes) {
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
.nOut(5).build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.setInputType(InputType.recurrent(1)).build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.1)).seed(12345).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
.build())
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
.nOut(4).build())
.setInputType(InputType.recurrent(1)).build();
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
net2.init();
net2.setParams(net.params().dup());
INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
int inIdx = r.nextInt(numInputClasses);
inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx);
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
int outIdx = r.nextInt(4);
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
}
}
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
for (int i = 0; i < nExamples; i++) {
for (int j = 0; j < timeSeriesLength; j++) {
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
}
}
net.setLayerMaskArrays(inputMask, null);
net2.setLayerMaskArrays(inputMask, null);
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
List<INDArray> actDense = net2.feedForward(inDense, false);
for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape)
assertEquals(actDense.get(i), actEmbedding.get(i));
}
net.setLabels(labels);
net2.setLabels(labels);
net.computeGradientAndScore();
net2.computeGradientAndScore();
assertEquals(net2.score(), net.score(), 1e-5);
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
assertEquals(gradients.keySet(), gradients2.keySet());
for (String s : gradients.keySet()) {
assertEquals(gradients2.get(s), gradients.get(s));
}
}
}
}
}
}
@EqualsAndHashCode
private static class WordVectorsMockup implements EmbeddingInitializer {

View File

@ -213,6 +213,12 @@ public class TestSameDiffConv extends BaseDL4JTest {
INDArray outLoaded = netLoaded.output(in);
assertEquals(msg, outExp, outLoaded);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = net.output(newIn);
INDArray outMb = net2.output(newIn);
assertEquals(outMb, outMbsd);
}
}
}
@ -306,6 +312,10 @@ public class TestSameDiffConv extends BaseDL4JTest {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(f, f);
net.output(newIn);
}
}
}

View File

@ -137,6 +137,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
INDArray outLoaded = netLoaded.output(in);
assertEquals(outExp, outLoaded);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = net.output(newIn);
INDArray outMb = net2.output(newIn);
assertEquals(outMb, outMbsd);
}
}
}
@ -314,6 +320,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
netSD.computeGradientAndScore();
// netStandard.computeGradientAndScore();
// assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient());
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStandard.output(newIn);
assertEquals(outMb, outMbsd);
}
}
}
@ -377,6 +389,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
assertEquals(s, netStandard.params(), netSD.params());
assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
}
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures());
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStandard.output(newIn);
assertEquals(outMb, outMbsd);
}
@Test
@ -417,6 +435,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(net);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(f, f);
net.output(newIn);
}
}
}

View File

@ -166,6 +166,12 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
outSD = loaded.outputSingle(in);
outStd = netStandard.outputSingle(in);
assertEquals(outStd, outSD);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn)[0];
INDArray outMb = netStandard.output(newIn)[0];
assertEquals(outMb, outMbsd);
}
}
}

View File

@ -115,6 +115,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
outStd = std.outputSingle(in);
assertEquals(outStd, outLambda);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = lambda.output(newIn)[0];
INDArray outMb = std.output(newIn)[0];
assertEquals(outMb, outMbsd);
}
@Test
@ -186,5 +192,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
outStd = std.output(in1, in2)[0];
assertEquals(outStd, outLambda);
//Sanity check on different minibatch sizes:
INDArray newIn1 = Nd4j.vstack(in1, in1);
INDArray newIn2 = Nd4j.vstack(in2, in2);
INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
INDArray outMb = std.output(newIn1, newIn2)[0];
assertEquals(outMb, outMbsd);
}
}

View File

@ -90,6 +90,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
net.init();
net.fit(ds);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStd.output(newIn);
assertEquals(outMb, outMbsd);
}
@ -164,6 +170,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
net.init();
net.fit(ds);
//Sanity check on different minibatch sizes:
INDArray newIn = Nd4j.vstack(in, in);
INDArray outMbsd = netSD.output(newIn);
INDArray outMb = netStd.output(newIn);
assertEquals(outMb, outMbsd);
}
}

View File

@ -32,6 +32,7 @@ import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.fasttext.FastText;
import org.deeplearning4j.models.glove.Glove;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
@ -3089,6 +3090,42 @@ public class WordVectorSerializer {
word2Vec.setModelUtils(vectors.getModelUtils());
return word2Vec;
}
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
ObjectOutputStream outputStream = null;
try {
outputStream = new ObjectOutputStream(new FileOutputStream(path ));
outputStream.writeObject(vectors);
}
finally {
try {
if (outputStream != null) {
outputStream.flush();
outputStream.close();
}
} catch (IOException ex) {
ex.printStackTrace();
}
}
}
public static FastText readWordVectors(File path) {
FastText result = null;
try {
FileInputStream fileIn = new FileInputStream(path);
ObjectInputStream in = new ObjectInputStream(fileIn);
try {
result = (FastText) in.readObject();
} catch (ClassNotFoundException ex) {
}
} catch (FileNotFoundException ex) {
ex.printStackTrace();
} catch (IOException ex) {
ex.printStackTrace();
}
return result;
}
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;

View File

@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
@ -207,7 +208,6 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
}
INDArray mean = words.isMatrix() ? words.mean(0) : words;
Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size());
List<String> realResults = new ArrayList<>();
@ -232,6 +232,22 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
return wordsNearestSum(vec, n);
}
protected INDArray adjustRank(INDArray words) {
if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
INDArray syn0 = l.getSyn0();
if (!words.dataType().equals(syn0.dataType())) {
return words.castTo(syn0.dataType());
}
if (words.rank() == 0 || words.rank() > 2) {
throw new IllegalStateException("Invalid rank for wordsNearest method");
} else if (words.rank() == 1) {
return words.reshape(1, -1);
}
}
return words;
}
/**
* Words nearest based on positive and negative words
* * @param top the top n words
@ -239,6 +255,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
*/
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
words = adjustRank(words);
if (lookupTable instanceof InMemoryLookupTable) {
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models.embeddings.reader.impl;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
@ -64,6 +65,8 @@ public class FlatModelUtils<T extends SequenceElement> extends BasicModelUtils<T
public Collection<String> wordsNearest(INDArray words, int top) {
Counter<String> distances = new Counter<>();
words = adjustRank(words);
for (String s : vocabCache.words()) {
INDArray otherVec = lookupTable.vector(s);
double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup()));

View File

@ -103,6 +103,7 @@ public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
checkTree();
words = adjustRank(words);
List<DataPoint> add = new ArrayList<>();
List<Double> distances = new ArrayList<>();

View File

@ -172,4 +172,10 @@ public interface WordVectors extends Serializable, EmbeddingInitializer {
*/
void setModelUtils(ModelUtils utils);
/**
* Does implementation vectorize words absent in vocabulary
* @return boolean
*/
boolean outOfVocabularySupported();
}

View File

@ -20,6 +20,7 @@ import com.google.common.util.concurrent.AtomicDouble;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.val;
import org.apache.commons.lang.ArrayUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
@ -357,4 +358,9 @@ public class WordVectorsImpl<T extends SequenceElement> implements WordVectors {
public boolean jsonSerializable() {
return false;
}
@Override
public boolean outOfVocabularySupported() {
return false;
}
}

View File

@ -6,10 +6,13 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -17,41 +20,78 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.io.*;
import java.util.*;
@Slf4j
@AllArgsConstructor
@lombok.Builder
public class FastText implements WordVectors {
public class FastText implements WordVectors, Serializable {
private boolean supervised;
private boolean quantize;
private boolean predict;
private boolean predict_prob;
// Mandatory
@Getter private String inputFile;
@Getter private String outputFile;
private boolean skipgram;
@Builder.Default private int bucket = 100;
@Builder.Default private int minCount = 1;
// Optional for dictionary
@Builder.Default private int bucket = -1;
@Builder.Default private int minCount = -1;
@Builder.Default private int minCountLabel = -1;
@Builder.Default private int wordNgrams = -1;
@Builder.Default private int minNgramLength = -1;
@Builder.Default private int maxNgramLength = -1;
@Builder.Default private int samplingThreshold = -1;
private String labelPrefix;
private boolean cbow;
private boolean nn;
private boolean analogies;
private String inputFile;
private String outputFile;
private SentenceIterator iterator;
private String modelName;
private String lossName;
//TODO:
private double[] pretrainedVectors;
// Optional for training
@Getter private boolean supervised;
@Getter private boolean quantize;
@Getter private boolean predict;
@Getter private boolean predict_prob;
@Getter private boolean skipgram;
@Getter private boolean cbow;
@Getter private boolean nn;
@Getter private boolean analogies;
@Getter private String pretrainedVectorsFile;
@Getter
@Builder.Default
private double learningRate = -1.0;
@Getter private double learningRateUpdate = -1.0;
@Getter
@Builder.Default
private int dim = -1;
@Getter
@Builder.Default
private int contextWindowSize = -1;
@Getter
@Builder.Default
private int epochs = -1;
@Getter private String modelName;
@Getter private String lossName;
@Getter
@Builder.Default
private int negativeSamples = -1;
@Getter
@Builder.Default
private int numThreads = -1;
@Getter private boolean saveOutput = false;
private JFastText fastTextImpl;
private boolean modelLoaded;
// Optional for quantization
@Getter
@Builder.Default
private int cutOff = -1;
@Getter private boolean retrain;
@Getter private boolean qnorm;
@Getter private boolean qout;
@Getter
@Builder.Default
private int dsub = -1;
@Getter private SentenceIterator iterator;
@Builder.Default private transient JFastText fastTextImpl = new JFastText();
private transient Word2Vec word2Vec;
@Getter private boolean modelLoaded;
@Getter private boolean modelVectorsLoaded;
private VocabCache vocabCache;
public FastText(File modelPath) {
@ -63,8 +103,97 @@ public class FastText implements WordVectors {
fastTextImpl = new JFastText();
}
public void init() {
fastTextImpl = new JFastText();
private static class ArgsFactory {
private List<String> args = new ArrayList<>();
private void add(String label, String value) {
args.add(label);
args.add(value);
}
private void addOptional(String label, int value) {
if (value >= 0) {
args.add(label);
args.add(Integer.toString(value));
}
}
private void addOptional(String label, double value) {
if (value >= 0.0) {
args.add(label);
args.add(Double.toString(value));
}
}
private void addOptional(String label, String value) {
if (StringUtils.isNotEmpty(value)) {
args.add(label);
args.add(value);
}
}
private void addOptional(String label, boolean value) {
if (value) {
args.add(label);
}
}
public String[] args() {
String[] asArray = new String[args.size()];
return args.toArray(asArray);
}
}
private String[] makeArgs() {
ArgsFactory argsFactory = new ArgsFactory();
argsFactory.addOptional("cbow", cbow);
argsFactory.addOptional("skipgram", skipgram);
argsFactory.addOptional("supervised", supervised);
argsFactory.addOptional("quantize", quantize);
argsFactory.addOptional("predict", predict);
argsFactory.addOptional("predict_prob", predict_prob);
argsFactory.add("-input", inputFile);
argsFactory.add("-output", outputFile );
argsFactory.addOptional("-pretrainedVectors", pretrainedVectorsFile);
argsFactory.addOptional("-bucket", bucket);
argsFactory.addOptional("-minCount", minCount);
argsFactory.addOptional("-minCountLabel", minCountLabel);
argsFactory.addOptional("-wordNgrams", wordNgrams);
argsFactory.addOptional("-minn", minNgramLength);
argsFactory.addOptional("-maxn", maxNgramLength);
argsFactory.addOptional("-t", samplingThreshold);
argsFactory.addOptional("-label", labelPrefix);
argsFactory.addOptional("analogies",analogies);
argsFactory.addOptional("-lr", learningRate);
argsFactory.addOptional("-lrUpdateRate", learningRateUpdate);
argsFactory.addOptional("-dim", dim);
argsFactory.addOptional("-ws", contextWindowSize);
argsFactory.addOptional("-epoch", epochs);
argsFactory.addOptional("-loss", lossName);
argsFactory.addOptional("-neg", negativeSamples);
argsFactory.addOptional("-thread", numThreads);
argsFactory.addOptional("-saveOutput", saveOutput);
argsFactory.addOptional("-cutoff", cutOff);
argsFactory.addOptional("-retrain", retrain);
argsFactory.addOptional("-qnorm", qnorm);
argsFactory.addOptional("-qout", qout);
argsFactory.addOptional("-dsub", dsub);
return argsFactory.args();
}
public void fit() {
String[] cmd = makeArgs();
fastTextImpl.runCmd(cmd);
}
public void loadIterator() {
if (iterator != null) {
try {
File tempFile = File.createTempFile("FTX", ".txt");
@ -81,24 +210,11 @@ public class FastText implements WordVectors {
}
}
public void fit() {
String[] cmd;
if (skipgram) {
cmd = new String[]{"skipgram", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
"-input", inputFile, "-output", outputFile};
}
else if (cbow) {
cmd = new String[]{"cbow", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
"-input", inputFile, "-output", outputFile};
}
else if (supervised)
cmd = new String[]{"supervised", "-input", inputFile,
"-output", outputFile};
else
cmd = new String[]{"-input", inputFile,
"-output", outputFile};
fastTextImpl.runCmd(cmd);
public void loadPretrainedVectors(File vectorsFile) {
word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
modelVectorsLoaded = true;
log.info("Loaded vectorized representation from file %s. Functionality will be restricted.",
vectorsFile.getAbsolutePath());
}
public void loadBinaryModel(String modelPath) {
@ -111,10 +227,18 @@ public class FastText implements WordVectors {
modelLoaded = false;
}
public void test(File testFile) {
fastTextImpl.test(testFile.getAbsolutePath());
}
private void assertModelLoaded() {
if (!modelLoaded && !modelVectorsLoaded)
throw new IllegalStateException("Model must be loaded before predict!");
}
public String predict(String text) {
if (!modelLoaded)
throw new IllegalStateException("Model must be loaded before predict!");
assertModelLoaded();
String label = fastTextImpl.predict(text);
return label;
@ -122,8 +246,7 @@ public class FastText implements WordVectors {
public Pair<String, Float> predictProbability(String text) {
if (!modelLoaded)
throw new IllegalStateException("Model must be loaded before predict!");
assertModelLoaded();
JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);
@ -135,27 +258,39 @@ public class FastText implements WordVectors {
@Override
public VocabCache vocab() {
if (!modelLoaded)
throw new IllegalStateException("Load model before calling vocab()");
if (vocabCache == null) {
vocabCache = new AbstractCache();
if (modelVectorsLoaded) {
vocabCache = word2Vec.vocab();
}
List<String> words = fastTextImpl.getWords();
for (int i = 0; i < words.size(); ++i) {
vocabCache.addWordToIndex(i, words.get(i));
VocabWord word = new VocabWord();
word.setWord(words.get(i));
vocabCache.addToken(word);
else {
if (!modelLoaded)
throw new IllegalStateException("Load model before calling vocab()");
if (vocabCache == null) {
vocabCache = new AbstractCache();
}
List<String> words = fastTextImpl.getWords();
for (int i = 0; i < words.size(); ++i) {
vocabCache.addWordToIndex(i, words.get(i));
VocabWord word = new VocabWord();
word.setWord(words.get(i));
vocabCache.addToken(word);
}
}
return vocabCache;
}
@Override
public long vocabSize() {
if (!modelLoaded)
throw new IllegalStateException("Load model before calling vocab()");
return fastTextImpl.getNWords();
long result = 0;
if (modelVectorsLoaded) {
result = word2Vec.vocabSize();
}
else {
if (!modelLoaded)
throw new IllegalStateException("Load model before calling vocab()");
result = fastTextImpl.getNWords();
}
return result;
}
@Override
@ -170,99 +305,160 @@ public class FastText implements WordVectors {
@Override
public double[] getWordVector(String word) {
List<Float> vectors = fastTextImpl.getVector(word);
double[] retVal = new double[vectors.size()];
for (int i = 0; i < vectors.size(); ++i) {
retVal[i] = vectors.get(i);
if (modelVectorsLoaded) {
return word2Vec.getWordVector(word);
}
else {
List<Float> vectors = fastTextImpl.getVector(word);
double[] retVal = new double[vectors.size()];
for (int i = 0; i < vectors.size(); ++i) {
retVal[i] = vectors.get(i);
}
return retVal;
}
return retVal;
}
@Override
public INDArray getWordVectorMatrixNormalized(String word) {
INDArray r = getWordVectorMatrix(word);
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
if (modelVectorsLoaded) {
return word2Vec.getWordVectorMatrixNormalized(word);
}
else {
INDArray r = getWordVectorMatrix(word);
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
}
}
@Override
public INDArray getWordVectorMatrix(String word) {
double[] values = getWordVector(word);
return Nd4j.createFromArray(values);
if (modelVectorsLoaded) {
return word2Vec.getWordVectorMatrix(word);
}
else {
double[] values = getWordVector(word);
return Nd4j.createFromArray(values);
}
}
@Override
public INDArray getWordVectors(Collection<String> labels) {
if (modelVectorsLoaded) {
return word2Vec.getWordVectors(labels);
}
return null;
}
@Override
public INDArray getWordVectorsMean(Collection<String> labels) {
if (modelVectorsLoaded) {
return word2Vec.getWordVectorsMean(labels);
}
return null;
}
private List<String> words = new ArrayList<>();
@Override
public boolean hasWord(String word) {
return fastTextImpl.getWords().contains(word);
if (modelVectorsLoaded) {
return word2Vec.outOfVocabularySupported();
}
if (words.isEmpty())
words = fastTextImpl.getWords();
return words.contains(word);
}
protected transient ModelUtils modelUtils = new BasicModelUtils<>();
protected transient ModelUtils modelUtils;
@Override
public Collection<String> wordsNearest(INDArray words, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(words, top);
}
return modelUtils.wordsNearest(words, top);
}
@Override
public Collection<String> wordsNearestSum(INDArray words, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(words, top);
}
return modelUtils.wordsNearestSum(words, top);
}
@Override
public Collection<String> wordsNearestSum(String word, int n) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(word, n);
}
return modelUtils.wordsNearestSum(word, n);
}
@Override
public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearestSum(positive, negative, top);
}
return modelUtils.wordsNearestSum(positive, negative, top);
}
@Override
public Map<String, Double> accuracy(List<String> questions) {
if (modelVectorsLoaded) {
return word2Vec.accuracy(questions);
}
return modelUtils.accuracy(questions);
}
@Override
public int indexOf(String word) {
if (modelVectorsLoaded) {
return word2Vec.indexOf(word);
}
return vocab().indexOf(word);
}
@Override
public List<String> similarWordsInVocabTo(String word, double accuracy) {
if (modelVectorsLoaded) {
return word2Vec.similarWordsInVocabTo(word, accuracy);
}
return modelUtils.similarWordsInVocabTo(word, accuracy);
}
@Override
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(positive, negative, top);
}
return modelUtils.wordsNearest(positive, negative, top);
}
@Override
public Collection<String> wordsNearest(String word, int n) {
if (modelVectorsLoaded) {
return word2Vec.wordsNearest(word,n);
}
return modelUtils.wordsNearestSum(word, n);
}
@Override
public double similarity(String word, String word2) {
if (modelVectorsLoaded) {
return word2Vec.similarity(word, word2);
}
return modelUtils.similarity(word, word2);
}
@Override
public WeightLookupTable lookupTable() {
if (modelVectorsLoaded) {
return word2Vec.lookupTable();
}
return null;
}
@ -320,4 +516,9 @@ public class FastText implements WordVectors {
return fastTextImpl.getLabelPrefix();
}
@Override
public boolean outOfVocabularySupported() {
return true;
}
}

View File

@ -376,6 +376,11 @@ public class StaticWord2Vec implements WordVectors {
return false;
}
@Override
public boolean outOfVocabularySupported() {
return false;
}
public static class Builder {
private AbstractStorage<Integer> storage;

View File

@ -1,6 +1,10 @@
package org.deeplearning4j.models.fasttext;
import com.github.jfasttext.JFastText;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -13,6 +17,7 @@ import org.nd4j.resources.Resources;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import static org.junit.Assert.assertArrayEquals;
@ -23,7 +28,9 @@ import static org.junit.Assert.assertEquals;
public class FastTextTest extends BaseDL4JTest {
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
private File modelFile = Resources.asFile("models/fasttext/supervised.model.bin");
private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
@Rule
@ -39,7 +46,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit();
}
@ -53,7 +59,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit();
}
@ -68,7 +73,6 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit();
}
@ -82,34 +86,42 @@ public class FastTextTest extends BaseDL4JTest {
inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.init();
fastText.fit();
}
@Ignore
@Test
public void tesLoadCBOWModel() throws IOException {
FastText fastText = new FastText(cbowModelFile);
fastText.test(cbowModelFile);
assertEquals(19, fastText.vocab().numWords());
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
}
@Test
public void testPredict() throws IOException {
for (int i = 0; i < 100; ++i) {
String text = "I like soccer";
FastText fastText = new FastText(modelFile);
FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords());
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-5);
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
String label = fastText.predict(text);
assertEquals("__label__soccer", label);
}
}
@Ignore
@Test
public void testPredictProbability() throws IOException {
String text = "I like soccer";
FastText fastText = new FastText(modelFile);
FastText fastText = new FastText(supModelFile);
Pair<String,Float> result = fastText.predictProbability(text);
assertEquals("__label__soccer", result.getFirst());
@ -129,7 +141,7 @@ public class FastTextTest extends BaseDL4JTest {
@Test
public void testVocabulary() throws IOException {
FastText fastText = new FastText(modelFile);
FastText fastText = new FastText(supModelFile);
assertEquals(48, fastText.vocab().numWords());
assertEquals(48, fastText.vocabSize());
@ -149,7 +161,7 @@ public class FastTextTest extends BaseDL4JTest {
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
FastText fastText =
FastText.builder().supervised(true).iterator(iter).build();
fastText.init();
fastText.loadIterator();
} catch (IOException e) {
log.error(e.toString());
@ -162,4 +174,60 @@ public class FastTextTest extends BaseDL4JTest {
String label = fastText.predict("something");
}
@Test
public void testPretrainedVectors() throws IOException {
File output = testDir.newFile();
FastText fastText =
FastText.builder().supervised(true).
inputFile(inputFile.getAbsolutePath()).
pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.fit();
}
@Test
public void testWordsStatistics() throws IOException {
File output = testDir.newFile();
FastText fastText =
FastText.builder().supervised(true).
inputFile(inputFile.getAbsolutePath()).
outputFile(output.getAbsolutePath()).build();
log.info("\nTraining supervised model ...\n");
fastText.fit();
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
assertEquals(48, word2Vec.getVocab().numWords());
System.out.println(word2Vec.wordsNearest("association", 3));
System.out.println(word2Vec.similarity("Football", "teams"));
System.out.println(word2Vec.similarity("professional", "minutes"));
System.out.println(word2Vec.similarity("java","cpp"));
}
@Test
public void testWordsNativeStatistics() throws IOException {
File output = testDir.newFile();
FastText fastText = new FastText();
fastText.loadPretrainedVectors(supervisedVectors);
log.info("\nTraining supervised model ...\n");
assertEquals(48, fastText.vocab().numWords());
String[] result = new String[3];
fastText.wordsNearest("association", 3).toArray(result);
assertArrayEquals(new String[]{"most","eleven","hours"}, result);
assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
}
}

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
import org.deeplearning4j.models.fasttext.FastText;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
@ -42,6 +43,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import static org.junit.Assert.*;
@ -289,4 +291,33 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
}
}
}
@Test
public void FastText_Correct_WhenDeserialized() throws IOException {
FastText fastText =
FastText.builder().cbow(true).build();
WordVectorSerializer.writeWordVectors(fastText, new File("some.data"));
FastText deser = null;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
deser = WordVectorSerializer.readWordVectors(new File("some.data"));
} catch (Exception e) {
e.printStackTrace();
fail();
}
assertNotNull(deser);
assertEquals(fastText.isCbow(), deser.isCbow());
assertEquals(fastText.isModelLoaded(), deser.isModelLoaded());
assertEquals(fastText.isAnalogies(), deser.isAnalogies());
assertEquals(fastText.isNn(), deser.isNn());
assertEquals(fastText.isPredict(), deser.isPredict());
assertEquals(fastText.isPredict_prob(), deser.isPredict_prob());
assertEquals(fastText.isQuantize(), deser.isQuantize());
assertEquals(fastText.getInputFile(), deser.getInputFile());
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
}
}

View File

@ -453,6 +453,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
DataType netDtype = getConfiguration().getDataType();
if(parameters != null && parameters.dataType() != netDtype){
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
if(cloneParametersArray){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
parameters = parameters.castTo(netDtype);

View File

@ -178,8 +178,7 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
", mask shape: " + Arrays.toString(maskArray.shape()));
}
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
Broadcast.mul(ret, maskArray, ret, 0, 2);
// ret.muliColumnVector(maskArray);
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
}
return ret;
}

View File

@ -616,6 +616,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
DataType netDtype = getLayerWiseConfigurations().getDataType();
if(parameters != null && parameters.dataType() != netDtype){
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
if(cloneParametersArray){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
parameters = parameters.castTo(netDtype);
@ -627,6 +628,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
}
}
if (layerMap == null)
layerMap = new LinkedHashMap<>();

View File

@ -32,7 +32,7 @@ import java.util.Arrays;
@Ignore
public class TestSameDiffUI {
// @Ignore
@Ignore
@Test
public void testSameDiff() throws Exception {

View File

@ -1598,9 +1598,6 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
int NDArray::rankOf() const {
if (isEmpty())
return 0;
return shape::rank(_shapeInfo);
}

View File

@ -36,7 +36,7 @@ std::string NDArray::e(const Nd4jLong i) const;
template <typename T>
NDArray* NDArray::asT() const{
auto result = new NDArray(ordering(), isScalar() ? std::vector<Nd4jLong>({0}) : getShapeAsVector(), DataTypeUtils::fromT<T>());
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
auto l = this->lengthOf();
prepareSpecialUse({result}, {this});
@ -67,17 +67,18 @@ NDArray::NDArray(const NDArray& other) {
////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
if (shape.empty())
throw std::runtime_error("NDArray constructor: input shape is empty !");
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("Rank of NDArray can't exceed 32");
_context = context;
_isAttached = getContext()->getWorkspace() != nullptr;
_isAttached = _context->getWorkspace() != nullptr;
_offset = 0;
setShapeInfo(ShapeDescriptor(dtype, order, shape));
if (shape.empty())
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
else
setShapeInfo(ShapeDescriptor(dtype, order, shape));
_buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace());
_buffer->setToZeroBuffers();
}
@ -85,16 +86,20 @@ NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::Dat
////////////////////////////////////////////////////////////////////////
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) {
if (shape.empty())
throw std::runtime_error("NDArray constructor: input shape is empty !");
if ((int) shape.size() > MAX_RANK)
throw std::invalid_argument("Rank of NDArray can't exceed 32");
_context = context;
_offset = 0;
setShapeInfo(ShapeDescriptor(dtype, order, shape));
if (shape.size() == 0) {
if (data.size() == 0)
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
else
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
} else {
setShapeInfo(ShapeDescriptor(dtype, order, shape));
}
if (lengthOf() != data.size()) {
nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf());
@ -2441,6 +2446,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB()))
throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
if (isEmpty() || other->isEmpty())
return;
NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) {
@ -2513,6 +2521,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
if(target == nullptr || other == nullptr)
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !");
if (isEmpty() || other->isEmpty())
return;
NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) {
@ -2583,6 +2594,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
if (isEmpty() || other.isEmpty()) {
if (isEmpty())
return NDArray(*this);
else
return NDArray(other);
}
Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)()
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
@ -2812,6 +2830,19 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
return true;
const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end();
if(isEmpty() && !isOutShapeEmpty)
throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !");
if(!isEmpty() && isOutShapeEmpty)
throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !");
if(isEmpty() && isOutShapeEmpty) {
Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace());
setShapeInfo(shapeInfoNew);
RELEASE(shapeInfoNew, getContext()->getWorkspace());
return true;
}
std::vector<Nd4jLong> shape(cshape);
int rank = shape.size();
@ -2823,7 +2854,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
for (int i = 0; i < (int) shape.size(); i++) {
if (shape[i] < 0) {
if (numberNegativesOnes >= 1)
throw std::runtime_error("Only one dimension can be negative at once");
throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
numberNegativesOnes++;
@ -3664,7 +3695,7 @@ void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, co
if(rankOf() == copy.size() || copy.empty()) {
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo());
}
else {
else { //if (!isEmpty()) {
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
@ -4198,6 +4229,9 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& d
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
if(isEmpty())
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");
const int rank = rankOf();
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
@ -4260,6 +4294,9 @@ NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& di
////////////////////////////////////////////////////////////////////////
void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const {
if(isEmpty())
throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !");
const int rank = rankOf();
const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size();
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude);

View File

@ -1334,18 +1334,7 @@ public:
* @param npyArray
* @return
*/
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
unsigned int shapeSize = arr.shape.size();
auto shape = new unsigned int[shapeSize];
for(unsigned int i = 0; i < shapeSize; i++) {
shape[i] = arr.shape[i];
}
auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder);
delete[] shape;
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray);
/**

View File

@ -64,8 +64,8 @@ void NDArray::tickWriteDevice() const { }
void NDArray::tickReadHost() const { }
void NDArray::tickReadDevice() const { }
void NDArray::tickBothActual() const { }
bool NDArray::isActualOnHostSide() const { }
bool NDArray::isActualOnDeviceSide() const { }
bool NDArray::isActualOnHostSide() const { return true; }
bool NDArray::isActualOnDeviceSide() const { return true; }
void NDArray::makeBothBuffersActual() const { }
@ -419,328 +419,8 @@ void NDArray::repeat(int dimension, NDArray& target) const {
//////////////////////////////////////////////////////////////////////////
#ifndef __JAVACPP_HACK__
template<typename T>
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
#include "NDArrayLambda.hpp"
if (second == nullptr) {
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
throw std::runtime_error("second is null");
}
if (third == nullptr) {
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
throw std::runtime_error("third is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = second->bufferAsT<T>();
auto t = third->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func(f[e], s[e], t[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second->getOffset(e);
auto vOffset = third->getOffset(e);
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second->getOffset(e);
auto vOffset = third->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
}
}
}
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != other->dataType() || dataType() != target->dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func(f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset], s[yOffset]);
}
}
}
}
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset]);
}
}
}
}
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func(e, f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(e, f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(e, f[xOffset]);
}
}
}
}
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func((Nd4jLong) e, f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
}
}
}
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
#endif
/*

View File

@ -0,0 +1,325 @@
template<typename T>
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (second == nullptr) {
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
throw std::runtime_error("second is null");
}
if (third == nullptr) {
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
throw std::runtime_error("third is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = second->bufferAsT<T>();
auto t = third->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func(f[e], s[e], t[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second->getOffset(e);
auto vOffset = third->getOffset(e);
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto tOffset = this->getOffset(e);
auto uOffset = second->getOffset(e);
auto vOffset = third->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
}
}
}
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != other->dataType() || dataType() != target->dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func(f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset], s[yOffset]);
}
}
}
}
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++)
z[e] = func(f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(f[xOffset]);
}
}
}
}
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func(e, f[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
f[xOffset] = func(e, f[xOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func(e, f[xOffset]);
}
}
}
}
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
if (target == nullptr)
target = this;
if (other == nullptr) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
throw std::runtime_error("Other is null");
}
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
if(dataType() != target->dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
if (this->lengthOf() != other->lengthOf()) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
auto s = other->bufferAsT<T>();
auto z = target->bufferAsT<T>();
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (Nd4jLong e = 0; e < _length; e++)
z[e] = func((Nd4jLong) e, f[e], s[e]);
} else {
if (f == z) {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int e = 0; e < _length; e++) {
auto xOffset = this->getOffset(e);
auto yOffset = other->getOffset(e);
auto zOffset = target->getOffset(e);
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
}
}
}
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);

View File

@ -2710,6 +2710,32 @@ int NativeOps::dataTypeFromNpyHeader(void *header) {
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
}
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
unsigned int shapeSize = arr.shape.size();
std::vector<Nd4jLong> shape(shapeSize);
bool _empty = false;
for(unsigned int i = 0; i < shapeSize; i++) {
shape[i] = arr.shape[i];
if (arr.shape[i] == 0)
_empty = true;
}
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
Nd4jLong *shapeBuffer;
if (_empty) {
if (shapeSize > 0)
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
else
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
} else {
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);

View File

@ -454,7 +454,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
if (ews() != 1) {
for (uint i = 0; i < _length; i++)
cudaMemcpyAsync(pHost + i * sizeof(T), getSpecialBuffer() + getOffset(i) * sizeof(T), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
}
else
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
@ -475,6 +475,12 @@ template void NDArray::printCurrentBuffer<float>(const bool host, const char* ms
template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const;
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
#include <cpu/NDArrayLambda.hpp>
#endif
} // end namespace nd4j
#endif

View File

@ -3105,3 +3105,29 @@ nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, double
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
}
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
unsigned int shapeSize = arr.shape.size();
std::vector<Nd4jLong> shape(shapeSize);
bool _empty = false;
for(unsigned int i = 0; i < shapeSize; i++) {
shape[i] = arr.shape[i];
if (arr.shape[i] == 0)
_empty = true;
}
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
Nd4jLong *shapeBuffer;
if (_empty) {
if (shapeSize > 0)
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
else
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
} else {
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
}
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
}

View File

@ -53,6 +53,15 @@ namespace nd4j {
template <typename T>
FORCEINLINE static _CUDA_HD T max();
/**
* returns inf for float/double and max for everything else
*/
template <typename T>
FORCEINLINE static _CUDA_HD T infOrMax();
template <typename T>
FORCEINLINE static _CUDA_HD T nanOrZero();
// returns the difference between 1.0 and the next representable value of the given floating-point type
template <typename T>
FORCEINLINE static T eps();
@ -290,6 +299,36 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max<bfloat16>() {
return bfloat16::max();
}
template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax<float>() {
return std::numeric_limits<float>::infinity();
}
template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax<double>() {
return std::numeric_limits<double>::infinity();
}
template <typename T>
FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() {
return DataTypeUtils::max<T>();
}
template <>
FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero<float>() {
return std::numeric_limits<float>::quiet_NaN();
}
template <>
FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero<double>() {
return std::numeric_limits<double>::quiet_NaN();
}
template <typename T>
FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() {
return static_cast<T>(0);
}
FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
switch(dataType) {
case INT8:

View File

@ -55,8 +55,13 @@ bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const {
}
Nd4jLong* ShapeDescriptor::toShapeInfo() const {
if (_empty)
return ShapeBuilders::emptyShapeInfo(_dataType);
if (_empty) {
if (_rank == 0)
return ShapeBuilders::emptyShapeInfo(_dataType);
else {
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
}
}
switch (_rank) {
@ -133,15 +138,11 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
_rank = ((shape.size() == 1 && shape[0] == 0)? 0: shape.size());
_rank = shape.size();
_ews = 1;
if (_rank > 0) {
_strides.resize(_rank);
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
for (auto v:_shape) {
if (v == 0) {
@ -149,6 +150,17 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st
break;
}
}
// no point calculating strides for empty arrays
if (!_empty) {
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
} else {
// all strides set to 0
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
}
}
}
@ -191,8 +203,11 @@ ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
_empty = shape::isEmpty(shapeInfo);
for (int e = 0; e < _rank; e++)
for (int e = 0; e < _rank; e++) {
_shape.emplace_back(shapeInfo[e + 1]);
if (shapeInfo[e + 1] == 0)
_empty = true;
}
for (int e = 0; e < _rank; e++)
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
@ -304,7 +319,14 @@ ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const D
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._shape.emplace_back(length);
descriptor._strides.emplace_back(1);
if (length > 0)
descriptor._strides.emplace_back(1);
else {
descriptor._strides.emplace_back(0);
descriptor._empty = true;
}
descriptor._order = 'c';
descriptor._ews = 1;
descriptor._rank = 1;

View File

@ -29,7 +29,7 @@
#include <array/ArrayOptions.h>
namespace nd4j {
class ShapeBuilders {
class ND4J_EXPORT ShapeBuilders {
public:
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
@ -53,6 +53,8 @@ namespace nd4j {
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
};
}

View File

@ -40,6 +40,12 @@ namespace nd4j {
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
/**
* evaluate output shape for reduce operation when input shape is empty
* behavior is analogous to tf
*/
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
// evaluate shape for array which is result of repeat operation applied to arr
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);

View File

@ -71,7 +71,9 @@ namespace nd4j {
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto oPtr = new Nd4jLong[numOfSubArrs];
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);

View File

@ -75,7 +75,8 @@ namespace nd4j {
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto oPtr = new Nd4jLong[numOfSubArrs];
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
Nd4jPointer soPtr;
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));

View File

@ -54,11 +54,6 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) {
if (rank)
if(shapeOnly[0] == 0) // scalar case
rank = 0;
Nd4jLong* shapeInfo = nullptr;
if(rank == 0) { // scalar case
@ -67,10 +62,23 @@ namespace nd4j {
else {
ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
shapeInfo[0] = rank;
for(int i = 0; i < rank; ++i)
bool isEmpty = false;
for(int i = 0; i < rank; ++i) {
shapeInfo[i + 1] = shapeOnly[i];
shape::updateStrides(shapeInfo, order);
if (shapeOnly[i] == 0)
isEmpty = true;
}
if (!isEmpty) {
shape::updateStrides(shapeInfo, order);
}
else {
shapeInfo[shape::shapeInfoLength(rank) - 1] = order;
memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong));
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
}
nd4j::ArrayOptions::setDataType(shapeInfo, dataType);
}
@ -78,9 +86,16 @@ namespace nd4j {
}
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) {
auto shape = createScalarShapeInfo(dataType, workspace);
ArrayOptions::setPropertyBit(shape, ARRAY_EMPTY);
return shape;
auto shapeInfo = createScalarShapeInfo(dataType, workspace);
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
return shapeInfo;
}
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace) {
auto shapeInfo = createShapeInfo(dataType, order, shape, workspace);
memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong));
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
return shapeInfo;
}
////////////////////////////////////////////////////////////////////////////////

View File

@ -108,27 +108,81 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
//////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {
if (dimsToExclude.size() == 0) { // return copy of input shape
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
ShapeDescriptor descriptor(outShapeInfo, dataType);
RELEASE(outShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
const int rank = shape::rank(shapeInfo);
Nd4jLong* outShapeInfo = nullptr;
if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities
if(!keepDims)
outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
else
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
}
else {
shape::checkDimensions(rank, dimsToExclude);
std::vector<Nd4jLong> outShape;
if(keepDims) {
outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
for(const auto& dim : dimsToExclude)
outShape[dim] = 1;
}
else {
for (uint i = 0, j = 0; i < rank; ++i) {
if(j < dimsToExclude.size() && i == dimsToExclude[j])
++j;
else
outShape.emplace_back(shapeInfo[i + 1]);
}
}
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
}
ShapeDescriptor descriptor(outShapeInfo, dataType);
RELEASE(outShapeInfo, workspace);
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
}
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimensions, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
Nd4jLong* newShapeInfo = nullptr;
int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));
if (dimensions.size() == 0) { // return scalar or array with len=1 in this case
if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case
if(keepDims && rank > 1) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
@ -157,16 +211,16 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
}
}
shape::checkDimensions(rank, dimensions);
shape::checkDimensions(rank, dimsToExclude);
int dimSize = dimensions.size();
int dimSize = dimsToExclude.size();
if(keepDims) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
newShapeInfo[0] = rank;
for(int i = 0; i < rank; ++i)
if (std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
newShapeInfo[i+1] = 1;
else
newShapeInfo[i+1] = shapeInfo[i+1];
@ -178,7 +232,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
}
int newRank = rank - dimSize;
if (newRank==0 || (dimSize==1 && dimensions[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
if(supportOldShapes) {
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
@ -199,7 +253,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
newShapeInfo[0] = newRank; // set rank
int j=1;
for(int i = 0; i < rank; ++i)
if (!std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
newShapeInfo[j++] = shapeInfo[i+1];
//ensure whether vector has proper shape for old shape type
@ -208,7 +262,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
RELEASE(newShapeInfo, workspace);
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2
newShapeInfo[0] = 2;
if (dimensions[0] == 0) {
if (dimsToExclude[0] == 0) {
newShapeInfo[1] = 1;
newShapeInfo[2] = oldValue;
}
@ -422,8 +476,23 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
// nullify zero axis
for (int e = 0; e < maxRank; e++)
if (maxShapeInfo[e+1] == 0)
tmpShapeInfo[e+1] = 0;
int delta = maxRank - minRank;
for (int e = minRank - 1; e >= 0; e--)
if (minShapeInfo[e + 1] == 0)
tmpShapeInfo[e + 1 + delta] = 0;
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
if (shape::isEmpty(max) || shape::isEmpty(min)) {
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
}
ShapeDescriptor descriptor(tmpShapeInfo);
RELEASE(tmpShapeInfo, workspace);
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
@ -805,7 +874,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo,
nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
throw std::invalid_argument("");
}
return std::vector<Nd4jLong>({0});
return std::vector<Nd4jLong>({});
}

View File

@ -1992,7 +1992,7 @@ template <typename T>
len = shape::length(shapeInfo);
//check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute
if(len < 2)
if(len == 1)
return;
const int rank = shape::rank(shapeInfo);
@ -3961,7 +3961,7 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
newDim = newShape[newStart];
oldDim = oldShape[oldStart];
while (newDim != oldDim)
while (newDim != oldDim && newDim > 0 && oldDim > 0)
if (newDim < oldDim) newDim *= newShape[newStop++];
else oldDim *= oldShape[oldStop++];

View File

@ -41,7 +41,7 @@ void IndexReduce<X>::exec(const int opNum,
void *x, Nd4jLong *xShapeInfo,
void *extraParams,
Nd4jLong *z, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS);
@ -51,7 +51,7 @@ DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dime
template <typename X>
template<typename OpType>
Nd4jLong IndexReduce<X>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -116,13 +116,23 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
const Nd4jLong zLen = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto indexValue = OpType::startingIndexValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLen; i++)
z[i] = indexValue.index;;
return;
}
if(shape::isScalar(zShapeInfo)) {
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
return;
}
const Nd4jLong zLen = shape::length(zShapeInfo);
auto tadOnlyShapeInfo = tadShapeInfo;
Nd4jLong *tadOffsets = tadOffset;

View File

@ -45,7 +45,22 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
}
@ -82,7 +97,7 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
if (xEws >= 1) {
return execScalar<OpType>(x, xEws, length, extraParams);
}
@ -157,6 +172,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset
//the squeezed information doesn't render the right strides for
@ -212,9 +237,9 @@ namespace functions {
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -223,15 +248,15 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
startingVal = OpType::update(startingVal, local, extraParams);
}
}
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -240,8 +265,8 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
}
startingVal = OpType::update(startingVal, local, extraParams);
}
}
return OpType::postProcess(startingVal, length, extraParams);
}

View File

@ -45,7 +45,26 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
if (std::is_same<OpType, simdOps::Mean<X,Z>>::value) {
z[0] = nd4j::DataTypeUtils::nanOrZero<Z>();
} else {
z[0] = OpType::startingValue(x);
}
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws > 0) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
}
@ -69,7 +88,7 @@ namespace functions {
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
}
}
}
@ -165,6 +184,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset
//the squeezed information doesn't render the right strides for

View File

@ -46,6 +46,21 @@ namespace functions {
const Nd4jLong length = shape::length(xShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
}
@ -105,7 +120,7 @@ namespace functions {
delete[] intermediate;
return OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
}
}
}
@ -159,6 +174,16 @@ namespace functions {
auto resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset
//the squeezed information doesn't render the right strides for
@ -209,7 +234,7 @@ namespace functions {
template <typename X, typename Z>
template <typename OpType>
Z _CUDA_H ReduceLongFunction<X, Z>::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) {
auto x = reinterpret_cast<X *>(vx);
auto extraParams = reinterpret_cast<X *>(vextraParams);
@ -219,9 +244,9 @@ namespace functions {
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -230,15 +255,15 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
startingVal = OpType::update(startingVal, local, extraParams);
}
}
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -247,8 +272,8 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
}
startingVal = OpType::update(startingVal, local, extraParams);
}
}
return OpType::postProcess(startingVal, length, extraParams);
}

View File

@ -48,6 +48,20 @@ namespace functions {
const auto xEws = shape::elementWiseStride(xShapeInfo);
const int rank = shape::rank(xShapeInfo);
if (shape::isEmpty(xShapeInfo)) {
z[0] = OpType::startingValue(x);
return;
}
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
if (xEws >= 1) {
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
@ -71,7 +85,7 @@ namespace functions {
for (int e = 0; e < maxThreads; e++)
start = OpType::update(start, intermediate[e], extraParams);
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
z[0] = OpType::postProcess(start, length, extraParams);
}
}
@ -171,6 +185,16 @@ namespace functions {
auto zLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < zLength; i++)
z[i] = startingVal;
return;
}
//pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset
//the squeezed information doesn't render the right strides for
@ -231,9 +255,9 @@ namespace functions {
if (xEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -242,15 +266,15 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
startingVal = OpType::update(startingVal, local, extraParams);
}
}
else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{
{
auto local = OpType::startingValue(x);
auto threadNum = omp_get_thread_num();
auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset;
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
@ -259,8 +283,8 @@ namespace functions {
local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams);
PRAGMA_OMP_CRITICAL
startingVal = OpType::update(startingVal, local, extraParams);
}
startingVal = OpType::update(startingVal, local, extraParams);
}
}
return OpType::postProcess(startingVal, length, extraParams);
}

View File

@ -37,7 +37,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
@ -47,11 +47,21 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
auto xEws = shape::elementWiseStride(xShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY || nd4j::ArrayOptions::arrayType(yShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
const auto startingVal = OpType::startingValue(x);
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < length; i++)
z[i] = startingVal;
return;
}
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
// it's possible case for EqualsWithEps op
if (extraParams != nullptr)
if (extraParams != nullptr)
extraParamsVals[2] = extraParams[0];
uint xShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
@ -117,7 +127,7 @@ void Reduce3<X,Y>::execScalar(const int opNum,
void *extraParamsVals,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS);
}
@ -176,8 +186,8 @@ void Reduce3<X,Z>:: execAll(void *vx, Nd4jLong *xShapeInfo,
void *vextraParams,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
int *dimension, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) {
auto x = reinterpret_cast<X *>(vx);

View File

@ -47,8 +47,8 @@ namespace functions {
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *resultShapeInfoBuffer) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer), SUMMARY_STATS_OPS);
Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
}
template <typename X, typename Y>
@ -58,10 +58,10 @@ namespace functions {
Nd4jLong *xShapeInfo,
void *extraParams,
void *z,
Nd4jLong *resultShapeInfoBuffer,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer, dimension, dimensionLength), SUMMARY_STATS_OPS);
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
}
template <typename X, typename Z>
@ -71,7 +71,7 @@ namespace functions {
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vz,
Nd4jLong *resultShapeInfoBuffer) {
Nd4jLong *zShapeInfo) {
auto z = reinterpret_cast<Z*>(vz);
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
}
@ -86,12 +86,12 @@ namespace functions {
SummaryStatsData<X> startingIndex;
startingIndex.initialize();
auto length = shape::length(xShapeInfo);
uint xShapeInfoCast[MAX_RANK];
const bool canCast = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
for (Nd4jLong i = 0; i < length; i++) {
auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCast);
SummaryStatsData<X> curr;
@ -99,7 +99,7 @@ namespace functions {
startingIndex = update(startingIndex, curr, extraParams);
}
return OpType::getValue(biasCorrected, startingIndex);
return OpType::getValue(biasCorrected, startingIndex);
}
template <typename X, typename Z>
@ -108,20 +108,31 @@ namespace functions {
void *vx,
Nd4jLong *xShapeInfo,
void *vextraParams,
void *vresult,
Nd4jLong *resultShapeInfoBuffer,
void *vz,
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength) {
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vresult);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
if (shape::isScalar(resultShapeInfoBuffer)) {
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
auto x = reinterpret_cast<X *>(vx);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<Z *>(vextraParams);
int resultLength = shape::length(zShapeInfo);
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
return;
SummaryStatsData<X> comp;
comp.initWithValue(x[0]);
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
for (uint i = 0; i < resultLength; i++)
z[i] = OpType::getValue(biasCorrected, comp);
return;
}
if (shape::isScalar(zShapeInfo)) {
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
return;
}
//no-op
if (dimensionLength < 1)
@ -129,7 +140,6 @@ namespace functions {
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
int resultLength = shape::length(resultShapeInfoBuffer);
//pre squeezed: this is for keeping the pointer to the original
//shape information for tad offset
//the squeezed information doesn't render the right strides for
@ -149,7 +159,7 @@ namespace functions {
PRAGMA_OMP_PARALLEL_FOR
for (int r = 0; r < resultLength; r++) {
auto tadOffsetForBlock = tadPack.primaryOffsets()[r];
auto tx = x + tadOffsetForBlock;
SummaryStatsData<X> comp;

View File

@ -131,7 +131,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG);
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList;
}

View File

@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2d(*block.launchContext(), inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete outputReshaped;
@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
ConvolutionUtils::conv2dBP(*block.launchContext(), inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
delete inputReshaped;
delete gradIReshaped;

View File

@ -63,7 +63,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}
@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}
@ -305,7 +305,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}

View File

@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradW and gradB ----- //
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
if(gradB) {
@ -469,7 +469,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradI -----//
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
ConvolutionUtils::col2vol(*block.launchContext(), columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
if(!isNDHWC) {
delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}

View File

@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
ConvolutionUtils::col2vol(*block.launchContext(), columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
//----- add biases if required -----//
if(bias)
@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradW ----- //
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
ConvolutionUtils::vol2col(*block.launchContext(), *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
// ----- calculation of gradB ----- //

View File

@ -62,7 +62,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
return Status::OK();
}
@ -185,7 +185,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK();
}

View File

@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
return Status::OK();
}

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
//output->printBuffer("output op");
if (!isNCHW) {
@ -198,7 +198,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
// *gradI /= kH*kW;
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
if(!isNCHW) {
delete input;

View File

@ -69,7 +69,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
//T extraParams[] = {};
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) {
delete input;
@ -189,7 +189,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
if(!isNCDHW) {
delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
if (!isNCHW) {
delete input;
@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
if(!isNCHW) {
delete input;

View File

@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
if(isSameMode) // SAME
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
if(!isNCDHW) {
delete input;
@ -204,7 +204,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
// ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary;
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
if(!isNCDHW) {
delete input;

View File

@ -68,7 +68,7 @@ namespace nd4j {
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
if (!isNCHW) {
delete input;
@ -209,7 +209,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
if(!isNCHW) {
delete input;

View File

@ -84,11 +84,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
if (iC == 1) {
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
ConvolutionUtils::conv2d(*block.launchContext(), input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK();
}
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
return Status::OK();
}
@ -274,12 +274,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
ConvolutionUtils::conv2dBP(*block.launchContext(), resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
gradO = gradIDepth;
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
@ -288,7 +288,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
}
// ----- apply depthwise_conv2d_bp ----- //
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
if(weightsPoint)
delete gradO;

View File

@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf());
REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf());
ConvolutionUtils::upsampling2d(*block.launchContext(), *input, *output, factorH, factorW, (bool)isNCHW);
ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW);
return Status::OK();
}
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
ConvolutionUtils::upsampling2dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCHW);
ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW);
return Status::OK();
}

View File

@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf());
REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf());
ConvolutionUtils::upsampling3d(*block.launchContext(), *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
return Status::OK();
}
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
ConvolutionUtils::upsampling3dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCDHW);
ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW);
return Status::OK();
}

View File

@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
}
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
NDArray weights = NDArrayFactory::create<float>('c', shape, block.getWorkspace());
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
weights({0, 1, 0, 0}).assign(1.0f);
weights({1, 2, 0, 0}).assign(0.0f);

View File

@ -72,6 +72,11 @@ namespace nd4j {
if (dims.size() > 1)
std::sort(dims.begin(), dims.end());
for (auto d:dims) {
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));

View File

@ -72,6 +72,10 @@ namespace nd4j {
if (dims.size() > 1)
std::sort(dims.begin(), dims.end());
for (auto d:dims) {
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));

View File

@ -71,10 +71,8 @@ namespace nd4j {
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong);
newShape[0] = len;
auto empty = false;
for (int e = 0; e < shapeArray->lengthOf(); e++){
newShape[e+1] = shapeArray->e<Nd4jLong>(e);
empty |= (newShape[e+1] == 0); //Support "zeros in shape as empty" for TF import
}
nd4j::DataType dataType;
@ -90,10 +88,6 @@ namespace nd4j {
} else
throw std::runtime_error("Fill: missing value to fill output array with");
if(empty){
return SHAPELIST(ShapeBuilders::emptyShapeInfo(dataType, block.getWorkspace()));
}
ShapeUtils::updateStridesAndType(newShape, dataType, 'c');
return SHAPELIST(CONSTANT(newShape));

View File

@ -151,8 +151,10 @@ DECLARE_SHAPE_FN(range) {
delta = INPUT_VARIABLE(2)->e<double>(0);
}
if (limit == start)
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
if (limit == start){
//Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -177,8 +179,10 @@ DECLARE_SHAPE_FN(range) {
//nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta)
if (limit == start)
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
if (limit == start){
//Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -203,8 +207,10 @@ DECLARE_SHAPE_FN(range) {
delta = INT_ARG(2);
}
if (limit == start)
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(nd4j::DataType::INT32));
if (limit == start){
//Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::INT32));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
@ -233,9 +239,10 @@ DECLARE_SHAPE_FN(range) {
delta = T_ARG(2);
}
//REQUIRE_TRUE(limit != start, 0, "CUSTOM RANGE OP: limit and start values should be different, but got both equal to %f !", limit);
if (limit == start)
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(Environment::getInstance()->defaultFloatDataType()));
if (limit == start){
//Return [0] to match TF
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType()));
}
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");

View File

@ -31,7 +31,8 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
output->assign(static_cast<Nd4jLong>(input->rankOf()));
// output->assign(static_cast<Nd4jLong>(input->rankOf()));
output->assign(input->rankOf());
return Status::OK();
}

View File

@ -88,7 +88,7 @@ DECLARE_TYPES(reduce_max) {
->setSameMode(true);
}
#endif
#endif
#if NOT_EXCLUDED(OP_reduce_max_bp)
//////////////////////////////////////////////////////////////////////////

View File

@ -132,17 +132,13 @@ namespace nd4j {
REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]);
REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]);
REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]);
if(start == inShape[e+1] || size == 0 ){
empty = true;
if(start == inShape[e+1] ){
size = 0;
}
shape.emplace_back(size);
}
if(empty){
return SHAPELIST(ShapeBuilders::emptyShapeInfo(nd4j::DataType::INT32, block.getWorkspace()));
}
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
return SHAPELIST(newShape);
}

View File

@ -33,6 +33,10 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0;
if(dim < 0)
dim += input->rankOf() + 1;
// no-op in case of empty output array
if (output->isEmpty())
return Status::OK();
// input validation
// check whether shapes of all input array are the same
@ -47,16 +51,6 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
inArrs[i] = INPUT_VARIABLE(i);
helpers::stack(block.launchContext(), inArrs, output, dim);
// remove unity from output shape if input arrays are vectors
// if(input->isVector()) {
// std::vector<int> outShape(output->shapeOf(), output->shapeOf() + output->rankOf());
// outShape.erase(find(outShape.begin(), outShape.end(), 1));
// output->reshapei(output->ordering(), outShape);
// if(dim != 0 && (int)block.width() == 1) // such is implemented by tensorFlow
// output->permutei({1, 0});
// output->getShapeInfo()[output->rankOf()*2 + 2] = 1;
// }
return Status::OK();
}
@ -81,9 +75,23 @@ DECLARE_SHAPE_FN(stack) {
dim += rank + 1;
REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim);
// empty input arrays require some special handling
if (shape::isEmpty(inShapeInfo)) {
switch (rank) {
case 0: {
// we're going to return rank 1 here
if (block.width() == 1) {
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo)));
} else {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0}));
}
}
}
}
if(rank == 0) {
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
}
//the rank of output ShapeInfo is larger by one compared to input ShapeInfo
@ -91,13 +99,9 @@ DECLARE_SHAPE_FN(stack) {
// insert (int) block.width() at dim position of input shape to get output shape
outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width());
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
}
// 1) 1х4 + 1х4 = 2х1х4 (along dim=0) = 2x4
// 2) 1х4 + 1х4 = 1х2х4 (along dim=1) = 2x4
// 3) 4х1 + 4х1 = 2х4x1 (along dim=0) = 2x4
// 4) 4х1 + 4х1 = 4х2x1 (along dim=1) = 4x2
}
}

View File

@ -410,10 +410,13 @@ namespace nd4j {
// z->assign(x->e<float>(indices[0]));
// }
// else {
auto sub = (*x)(indices, true, true);
z->assign(sub);
// }
if (indices.size()) {
auto sub = (*x)(indices, true, true);
z->assign(sub);
}
else if (!z->isEmpty()){
z->assign(x->e(0));
}
return Status::OK();
}
DECLARE_SYN(stridedslice, strided_slice);
@ -496,28 +499,19 @@ namespace nd4j {
bool is_simple_slice;
bool is_dim0;
// FIXME: remove this, once we bring in 1D NDArrays
//vectorize(input_shape);
bool result = _preprocess_strided_slice(nullptr, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
bool nonEmpty = shape.size() > 0;
if (nonEmpty)
for (auto x: shape) {
if (x == 0) {
nonEmpty = false;
break;
}
}
if (nonEmpty && inputLen > 1) {
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
}
else {
if (shape::rank(inShape) == 0 || begin >= end) {
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
std::vector<Nd4jLong> indices;
bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
if (indices.size()) {
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
shape);
if (inputLen > 1) {
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
shape);
} else {
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
}
}
} else
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
return SHAPELIST(newShape);
}

View File

@ -37,6 +37,9 @@ namespace nd4j {
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
if(input->isEmpty())
return Status::OK();
std::vector<int> dims;
for (int e = 0; e < input->rankOf(); e++)
if (e != dim)
@ -65,7 +68,7 @@ namespace nd4j {
return Status::OK();
}
DECLARE_SYN(unpack, unstack);
DECLARE_SHAPE_FN(unstack) {
auto inShape = inputShape->at(0);
@ -76,6 +79,21 @@ namespace nd4j {
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
if(shape::shapeOf(inShape)[dim] == 0)
return SHAPELIST();
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
std::vector<Nd4jLong> outShape;
for(uint i = 0; i < shape::rank(inShape); ++i)
if(i != dim)
outShape.push_back(shape::shapeOf(inShape)[i]);
auto result = SHAPELIST();
for(uint i = 0; i < numTads; ++i)
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
return result;
}
std::vector<int> dims;
for (int e = 0; e < shape::rank(inShape); e++)
if (e != dim)

View File

@ -30,6 +30,12 @@ namespace nd4j {
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
if(input->isEmpty()){
output->p<double>(0, std::numeric_limits<double>::quiet_NaN());
return Status::OK();
}
int numZeros = 0;
// for (int e = 0; e < input->lengthOf(); e++)
// if ((*input)(e) == T(0))

View File

@ -113,16 +113,10 @@ DECLARE_SHAPE_FN(lstmBlock) {
}
ShapeUtils::updateStridesAndType(s, x, 'c');
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
COPY_SHAPE(s, s1);
COPY_SHAPE(s, s2);
COPY_SHAPE(s, s3);
COPY_SHAPE(s, s4);
COPY_SHAPE(s, s5);
COPY_SHAPE(s, s6);
Nd4jLong *s1 = CONSTANT(s);
//7 outputs, all same shape/type
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
}
}

View File

@ -115,16 +115,10 @@ DECLARE_SHAPE_FN(lstmBlockCell) {
ShapeUtils::updateStridesAndType(s, xt, 'c');
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
COPY_SHAPE(s, s1);
COPY_SHAPE(s, s2);
COPY_SHAPE(s, s3);
COPY_SHAPE(s, s4);
COPY_SHAPE(s, s5);
COPY_SHAPE(s, s6);
Nd4jLong *s1 = CONSTANT(s);
//7 outputs, all same shape: z, i, f, o, h, c, y
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
}
}

View File

@ -78,7 +78,6 @@ DECLARE_SHAPE_FN(broadcast_to) {
REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape);
return SHAPELIST(outShapeInfo);
}

View File

@ -34,26 +34,26 @@ namespace nd4j {
bool replace = false;
auto arguments = block.getIArguments();
if (block.width() == 2 && arguments->size() == 0) {
auto axis = INPUT_VARIABLE(1);
for (int e = 0; e < axis->lengthOf(); e++) {
int ax = axis->e<int>(e);
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
std::vector<int> arguments({});
if(origArgs.size() > 0){
for (int e = 0; e < origArgs.size(); e++) {
int ax = origArgs[e];
if (ax < 0)
ax += x->rankOf();
arguments->emplace_back(ax);
arguments.emplace_back(ax);
}
replace = true;
} else if (arguments->size() == 0) {
} else {
for (int e = x->rankOf() - 1; e >= 0; e--)
arguments->emplace_back(e);
arguments.emplace_back(e);
}
// 0D edge case
if (x->rankOf() == 0) {
REQUIRE_TRUE(arguments->size() == 1, 0, "Permute: only one axis is allowed for scalar");
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
auto output = OUTPUT_VARIABLE(0);
if (!block.isInplace())
output->assign(x);
@ -62,25 +62,17 @@ namespace nd4j {
}
if(block.isInplace()) { // in-place
x->permutei(*arguments);
x->permutei(arguments);
STORE_RESULT(x);
} else {
if (!replace) { // not-in-place
auto output = OUTPUT_VARIABLE(0);
// nd4j_printv("permute shape", *arguments);
auto result = x->permute(*arguments);
output->assign(result);
STORE_RESULT(output);
delete result;
} else {
auto output = OUTPUT_VARIABLE(0); //->dup();
output->assign(x);
output->permutei(*arguments);
//OVERWRITE_RESULT(output);
}
} else {
auto output = OUTPUT_VARIABLE(0);
auto result = x->permute(arguments);
output->assign(result);
STORE_RESULT(output);
delete result;
}
return Status::OK();
return Status::OK();
}
DECLARE_TYPES(permute) {
@ -92,20 +84,21 @@ namespace nd4j {
DECLARE_SHAPE_FN(permute) {
auto shapeList = SHAPELIST();
auto arguments = block.getIArguments();
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
if (shape::rank(inputShape->at(0)) == 0) {
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
} else if (inputShape->size() == 1 && !arguments->empty()) {
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
} else if (inputShape->size() == 2) {
// dead end
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
} else if (inputShape->size() == 1 && !arguments.empty()) {
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
} else {
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments->emplace_back(e);
if(arguments.size() == 0){
//Reverse dimensions
int rank = shape::rank(inputShape->at(0));
for (int e = rank - 1; e >= 0; e--)
arguments.emplace_back(e);
}
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
}
return shapeList;

View File

@ -35,9 +35,8 @@ namespace nd4j {
auto arguments = block.getIArguments();
int argsSize = arguments->size();
//Special case: empty.reshape(-1) -> return empty
//Special case: empty.reshape(<other empty shape>) -> return empty
if (x->isEmpty()) {
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return ND4J_STATUS_OK; //No op
}
@ -96,9 +95,9 @@ namespace nd4j {
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
return ND4J_STATUS_OK; //No op
return Status::OK(); //No op
}
char order = 'c';
@ -116,7 +115,8 @@ namespace nd4j {
}
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= s->e<Nd4jLong>(e2);
shapeLength *=
s->e<Nd4jLong>(e2);
}
long realShape = x->lengthOf() / shapeLength;
shapeNew[e] = realShape;
@ -175,12 +175,12 @@ namespace nd4j {
e = 0;
}
//Special case: empty.reshape(-1) -> return empty
if (INPUT_VARIABLE(0)->isEmpty()) {
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
return SHAPELIST(newShape);
}
// //Special case: empty.reshape(-1) -> return empty
// if (INPUT_VARIABLE(0)->isEmpty()) {
// //
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
// return SHAPELIST(newShape);
// }
std::vector<Nd4jLong> shapeNew;
@ -197,8 +197,14 @@ namespace nd4j {
shapeLength *= arguments->at(e2);
}
long realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
if(shapeLength == 0){
//Edge case for empty:
shapeNew.push_back(0);
} else {
//Standard case
long realShape = shape::length(inp) / shapeLength;
shapeNew.push_back(realShape);
}
}
else{
shapeNew.push_back(arguments->at(e));
@ -218,9 +224,16 @@ namespace nd4j {
}
//Special case: empty.reshape(-1) -> return empty
if (x->isEmpty()) {
REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
return SHAPELIST(newShape);
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
for (auto v:shapeOf)
prod *= v;
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}
std::vector<Nd4jLong> shapeNew(y->lengthOf());
@ -236,8 +249,14 @@ namespace nd4j {
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
shapeLength *= y->e<Nd4jLong>(e2);
}
long realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
if(shapeLength == 0){
//Edge case for empty:
shapeNew[e] = 0;
} else {
long realShape = shape::length(inp) / shapeLength;
shapeNew[e] = realShape;
}
}else {
shapeNew[e] = dim;
}

View File

@ -38,26 +38,31 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
std::vector<int> arrsToDelete;
int index = 0;
bool allOfSameType = true;
auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0;
auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType();
for(int i = 0; i < block.width(); ++i) {
if(!INPUT_VARIABLE(i)->isEmpty()) {
allOfSameType &= (INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType());
if(INPUT_VARIABLE(i)->rankOf() == 0) {
// FIXME, use this instead: block.dataType()
auto vec = new NDArray('c', {1}, INPUT_VARIABLE(0)->dataType(), block.launchContext());
vec->assign(INPUT_VARIABLE(i));
auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf();
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
if(!input->isEmpty()) {
allOfSameType &= (theFirstDatatype == input->dataType());
if(input->rankOf() == 0) {
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
vec->assign(input);
nonEmptyArrs.push_back(vec);
arrsToDelete.push_back(index);
}
else{
nonEmptyArrs.push_back(INPUT_VARIABLE(i));
nonEmptyArrs.push_back(input);
}
++index;
}
}
const int numOfArrs = nonEmptyArrs.size();
if(numOfArrs == 0){
@ -73,21 +78,21 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfArrs; ++i)
for(int i = 1; i < numOfArrs; ++i)
REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i) {
for(int i = 1; i < numOfArrs; ++i) {
for(int dim = 0; dim < rank; ++dim)
if(dim != axis)
if(dim != axis)
REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
}
// ******** end of input validation ******** //
auto output = OUTPUT_VARIABLE(0);
if(numOfArrs == 1)
if(numOfArrs == 1)
output->assign(nonEmptyArrs[0]);
else
else
helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);
// delete dynamically allocated vectors with length=1
@ -110,36 +115,27 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
DECLARE_SHAPE_FN(concat) {
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");
// first of all take into account possible presence of empty arrays
// also if scalar is present -> use the shape of vector with length=1 instead
std::vector<Nd4jLong*> nonEmptyArrShapes;
// also if scalar is present -> use the shape of vector with length=1 instead
std::vector<Nd4jLong*> arrShapes;
std::vector<int> shapesToDelete;
int index = 0;
for(int i = 0; i < block.width(); ++i) {
if(!INPUT_VARIABLE(i)->isEmpty()) {
if(inputShape->at(i)[0] == 0) {
// FIXME, use this instead: block.dataType()
nonEmptyArrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
}
else{
nonEmptyArrShapes.push_back(inputShape->at(i));
}
++index;
if(inputShape->at(i)[0] == 0) {
// FIXME, use this instead: block.dataType()
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
}
else{
arrShapes.push_back(inputShape->at(i));
}
++index;
}
const int numOfArrs = nonEmptyArrShapes.size();
const int numOfArrs = arrShapes.size();
if(numOfArrs == 0){
//All inputs are empty arrays -> return empty, mainly for TF import compatibility
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
return SHAPELIST(empty);
}
const int rank = nonEmptyArrShapes[0][0]; // look up to first non-empty array
const int rank = arrShapes[0][0];
int axis = INT_ARG(0);
if(axis < 0)
@ -148,34 +144,34 @@ DECLARE_SHAPE_FN(concat) {
// ******** input validation ******** //
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfArrs; ++i)
REQUIRE_TRUE(nonEmptyArrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i)
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
for(int i = 1; i < numOfArrs; ++i) {
for(int i = 1; i < numOfArrs; ++i) {
for(int dim = 0; dim < rank; ++dim)
if(dim != axis)
REQUIRE_TRUE(nonEmptyArrShapes[i][dim+1] == nonEmptyArrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
if(dim != axis)
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
}
// ******** end of input validation ******** //
Nd4jLong* outShapeInfo(nullptr);
COPY_SHAPE(nonEmptyArrShapes[0], outShapeInfo);
COPY_SHAPE(arrShapes[0], outShapeInfo);
// case when we have only one input array
if(numOfArrs == 1) {
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
if(numOfArrs == 1) {
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
return SHAPELIST(CONSTANT(outShapeInfo));
}
for(int i = 1; i < numOfArrs; ++i)
outShapeInfo[axis + 1] += nonEmptyArrShapes[i][axis + 1];
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
// delete dynamically allocated vectors shapes with length=1
for(int index : shapesToDelete)
RELEASE(nonEmptyArrShapes[index], block.getWorkspace());
RELEASE(arrShapes[index], block.getWorkspace());
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
RELEASE(outShapeInfo, block.getWorkspace());
@ -277,7 +273,7 @@ DECLARE_SHAPE_FN(concat) {
// DECLARE_SYN(ParallelConcat, concat);
// DECLARE_SYN(concat_v2, concat);
// DECLARE_SYN(concatv2, concat);
// DECLARE_SHAPE_FN(concat) {
// auto inp = inputShape->at(0);
// int _dimension = INT_ARG(0);
@ -338,7 +334,7 @@ DECLARE_SHAPE_FN(concat) {
// }
// }
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong);
// if (_dimension < 0)
@ -382,11 +378,11 @@ DECLARE_SHAPE_FN(concat) {
auto epsilonChunk = OUTPUT_VARIABLE(e);
std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());
int width = originalChunk->sizeAt(axis);
int width = originalChunk->sizeAt(axis);
for (int e = 0; e < epsilonNext->rankOf(); e++) {
if (e == axis)
indices[2*e + 1] = (indices[2*e] = startPos) + width;
indices[2*e + 1] = (indices[2*e] = startPos) + width;
else
indices[2*e + 1] = indices[2*e] = 0;
}
@ -394,7 +390,7 @@ DECLARE_SHAPE_FN(concat) {
auto subarray = (*epsilonNext)(indices, true);
epsilonChunk->assign(subarray);
startPos += width;
startPos += width;
}
return ND4J_STATUS_OK;

View File

@ -32,6 +32,11 @@ namespace nd4j {
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
if(input->isEmpty()){
//No-op
return Status::OK();
}
const bool exclusive = INT_ARG(0) == 1;
const bool reverse = INT_ARG(1) == 1;

View File

@ -36,6 +36,11 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) {
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
if(input->isEmpty()){
//No-op
return Status::OK();
}
if (block.getIArguments()->size() == 2 && block.width() == 1) {
// all at once case
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);

View File

@ -102,12 +102,6 @@ DECLARE_SHAPE_FN(gather) {
if(axis < 0)
axis += inputRank;
//Edge case: empty indices, empty input -> empty output
if(block.width() > 1 && INPUT_VARIABLE(0)->isEmpty() && INPUT_VARIABLE(1)->isEmpty()){
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
return SHAPELIST(empty);
}
REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
bool isEmpty = false;
@ -118,11 +112,6 @@ DECLARE_SHAPE_FN(gather) {
int indicesRank = shape::rank(indicesShapeInfo);
int outputRank = inputRank + indicesRank - 1;
if(INPUT_VARIABLE(1)->isEmpty()) { //Empty indices -> empty output
outputRank = 0;
isEmpty = true;
}
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);

View File

@ -33,6 +33,11 @@ namespace ops {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if(output->isEmpty()){
//No-op
return Status::OK();
}
std::vector<int> axis;
if (block.width() > 1)

View File

@ -204,6 +204,8 @@ namespace nd4j {
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescConv3d(
@ -212,56 +214,60 @@ namespace nd4j {
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
static void getMKLDNNMemoryDescPool2d(
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
static void getMKLDNNMemoryDescPool3d(
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
#endif
static void conv2d(nd4j::LaunchContext &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void conv2d(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
static void conv2dBP(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
static void conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
static void vol2col(nd4j::LaunchContext & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void col2vol(nd4j::LaunchContext & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void col2vol(nd4j::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
static void upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
static void upsampling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
static void upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
static void upsampling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
static void upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
static void upsampling2dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
static void upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
static void upsampling3dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
static void pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
static void pooling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
static void pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling2dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
static void pooling3dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
};
}

View File

@ -153,7 +153,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
if (inEWS == 1) {
PRAGMA_OMP_SIMD_MAX(max)
for (int i = 0; i < length; i++)
max = nd4j::math::nd4j_max<T>(max, outBuff[i]);
max = nd4j::math::nd4j_max<T>(max, inBuff[i]);
PRAGMA_OMP_SIMD_SUM(sum)
for (int i = 0; i < length; i++) {
@ -171,7 +171,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
PRAGMA_OMP_SIMD_MAX(max)
for (int i = 0; i < length; i++)
max = nd4j::math::nd4j_max<T>(max, outBuff[i * inEWS]);
max = nd4j::math::nd4j_max<T>(max, inBuff[i * inEWS]);
PRAGMA_OMP_SIMD_SUM(sum)
for (int i = 0; i < length; i++) {
@ -204,7 +204,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra
const int rank = input.rankOf();
if(input.isVector()) {
if(rank == 1 || input.sizeAt(dimension) != 1)
softMaxForVector_<T>(input.getBuffer(), input.getShapeInfo(), output.buffer(), output.getShapeInfo());
else
@ -228,7 +228,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra
T max = -DataTypeUtils::max<T>();
T sum = 0;
for(uint j = 0; j < tadLen; ++j)
max = nd4j::math::nd4j_max<T>(max, inBuff[j]);
@ -237,9 +237,9 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra
outBuff[j] = temp;
sum += temp;
}
for (uint j = 0; j < tadLen; ++j)
outBuff[j] /= sum;
outBuff[j] /= sum;
}
}
else {

File diff suppressed because it is too large Load Diff

View File

@ -227,15 +227,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
std::vector<Nd4jLong> shape = {bS, 4*numUnits};
auto m = NDArrayFactory::create_('c', shape, xt->dataType(), nullptr);
MmulHelper::mmul(&concatOut, W, m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
*m += (*b); //addiRowVector
auto m = NDArrayFactory::create('c', shape, xt->dataType());
MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
m += (*b); //addiRowVector
//Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
auto zi = (*m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
auto zz = (*m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
auto zf = (*m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
auto zo = (*m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
if(peephole) { // add peephole connections: z + ct_1*Wc
zi += (*cLast) * (*Wci); // add peephole connections to input gate

View File

@ -57,7 +57,7 @@ namespace helpers {
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
if (nullptr != indices) {
// for max_pool_with_argmax

View File

@ -53,9 +53,16 @@ namespace nd4j {
dtype = nd4j::DataType::BOOL;
if(shape::isEmpty(x) || shape::isEmpty(y)) {
//Edge case: broadcasting with empty array gives empty array output (behaviour to match TF for import cases)
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype);
shapeList->push_back(empty);
// this is edge case, [3, 4] + [] = []
if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
return shapeList;
}
Nd4jLong *newshape = nullptr;
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
} else if (shape::isScalar(x) && shape::isScalar(y)) {
if (shape::rank(x) >= shape::rank(y)) {
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));

View File

@ -2873,7 +2873,7 @@ namespace simdOps {
const static functions::ReduceType reduceType = functions::ReduceType::MAX;
op_def static X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>();
return -nd4j::DataTypeUtils::infOrMax<X>();
}
op_def static X merge(X old, X opOutput, X *extraParams) {
@ -3051,7 +3051,7 @@ namespace simdOps {
const static functions::ReduceType reduceType = functions::ReduceType::MIN;
op_def static X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>();
return nd4j::DataTypeUtils::infOrMax<X>();
}
op_def static X merge(X old, X opOutput, X *extraParams) {
@ -3831,7 +3831,7 @@ namespace simdOps {
}
static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>();
return -nd4j::DataTypeUtils::infOrMax<X>();
}
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3890,7 +3890,7 @@ namespace simdOps {
}
static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>();
return -nd4j::DataTypeUtils::infOrMax<X>();
}
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3958,7 +3958,7 @@ namespace simdOps {
}
static _CUDA_HD inline X startingValue(const X *input) {
return -nd4j::DataTypeUtils::max<X>();
return -nd4j::DataTypeUtils::infOrMax<X>();
}
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -3984,7 +3984,7 @@ namespace simdOps {
}
static _CUDA_HD inline X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>();
return nd4j::DataTypeUtils::infOrMax<X>();
}
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
@ -4040,7 +4040,7 @@ namespace simdOps {
}
static _CUDA_HD inline X startingValue(const X *input) {
return nd4j::DataTypeUtils::max<X>();
return nd4j::DataTypeUtils::infOrMax<X>();
}
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {

View File

@ -580,6 +580,152 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_1) {
ASSERT_TRUE(z.equalsTo(zExp));
}
TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
NDArray y('c', {1,4}, {1,2,3,4});
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
NDArray e = NDArrayFactory::create<double>('c', {0, 4});;
nd4j::ops::multiply op;
auto status = op.execute({&x, &y}, {&x}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(e.isSameShape(x));
ASSERT_TRUE(e.equalsTo(x));
}
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32);
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::maximum op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::maximum op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2, 1});
NDArray y = NDArrayFactory::create<float>('c', {1, 2, 0});
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});;
nd4j::ops::realdiv op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) {
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
NDArray x(nd4j::DataType::DOUBLE, y.getContext(), false);
NDArray z(nd4j::DataType::BOOL, y.getContext(), false);
NDArray zExp(nd4j::DataType::BOOL, y.getContext(), false);
nd4j::ops::greater op;
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(z.isSameShape(zExp));
ASSERT_TRUE(z.equalsTo(zExp));
}
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
NDArray y('c', {1,4}, {1,2,3,4});
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
NDArray e = NDArrayFactory::create<bool>('c', {0, 4});;
nd4j::ops::greater op;
auto result = op.execute({&x, &y}, {}, {}, {});
auto z = result->at(0);
z->printShapeInfo("z");
ASSERT_EQ(Status::OK(), result->status());
ASSERT_TRUE(e.isSameShape(z));
ASSERT_TRUE(e.equalsTo(*z));
delete result;
}
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32);

View File

@ -2021,7 +2021,8 @@ TEST_F(ConvolutionTests1, vol2col_test1) {
// PointersManager manager(columnsExpected.getContext());
// manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf());
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
graph::Context context(1);
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
ASSERT_TRUE(columns.equalsTo(columnsExpected));
}
@ -2052,7 +2053,8 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.});
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
graph::Context context(1);
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
ASSERT_TRUE(columns.equalsTo(columnsExpected));
}

View File

@ -1302,8 +1302,8 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
auto input = NDArrayFactory::create<double>(10.f);
auto shape = NDArrayFactory::create<double>(0.f);
auto exp = NDArrayFactory::create<double>(10.f);
auto shape = NDArrayFactory::create<Nd4jLong>(1);
auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
nd4j::ops::broadcast_to op;
auto results = op.execute({&input, &shape}, {}, {}, {});
@ -2261,8 +2261,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
NDArray min('c', {0}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {0}, {0.1f}, nd4j::DataType::FLOAT32);
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
nd4j::ops::fake_quant_with_min_max_vars op;
auto results = op.execute({&x, &min, &max}, {}, {});

View File

@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
NDArray dLdwExp('c', {0}, {-227.77286});
NDArray dLdwExp('c', {}, {-227.77286});
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.});
NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
@ -583,7 +583,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
NDArray dLdwExp('c', {0}, {4515.84});
NDArray dLdwExp('c', {}, {4515.84});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
@ -702,7 +702,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.});
NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
NDArray dLdwExp('c', {0}, {288.});
NDArray dLdwExp('c', {}, {288.});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.});
NDArray dLdwExp('c', {}, {0.});
predictions.linspace(0.04, 0.04);
labels.linspace(1);
@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
NDArray dLdwExp('c', {0}, {-91.52109});
NDArray dLdwExp('c', {}, {-91.52109});
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
NDArray weights(nd4j::DataType::DOUBLE);
NDArray dLdwExp('c', {0}, {0.});
NDArray dLdwExp('c', {}, {0.});
logits.linspace(-0.08, 0.04);
labels.linspace(1);
@ -2001,10 +2001,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
NDArray dLdwExp('c', {0}, {1.38629});
NDArray dLdwExp('c', {}, {1.38629});
logits = 2.;
weights.assign(0.5);
@ -2032,10 +2032,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
NDArray dLdwExp('c', {0}, {0.});
NDArray dLdwExp('c', {}, {0.});
logits.linspace(-0.08, 0.04);
weights = 0.5;
@ -2466,7 +2466,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
/////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
NDArray labels('c', {0}, {1}, nd4j::DataType::INT64);
NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
NDArray logits('c', {2}, {-0.2, 0.3});
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});

View File

@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE);
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
NDArray dLdwExp('c', {0}, {1.3});
NDArray dLdwExp('c', {}, {1.3});
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
predictions.linspace(-0.4, 0.2);
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
TEST_F(DeclarableOpsTests12, hinge_loss_14) {
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
NDArray weights('c', {0}, {1.});
NDArray weights('c', {}, {1.});
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
NDArray output('c', {0}, nd4j::DataType::DOUBLE);
NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE);
logits.linspace(1.);
weights.assign(1.);
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
TEST_F(DeclarableOpsTests12, reverse_test15) {
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
NDArray axis('c', {0}, {0}, nd4j::DataType::INT32);
NDArray axis('c', {}, {0}, nd4j::DataType::INT32);
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);

View File

@ -123,6 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("Reduced shape");
ASSERT_EQ(e, *z);
delete result;
@ -212,4 +213,200 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
ASSERT_EQ(e, *result->at(0));
delete result;
}
}
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
auto x = NDArrayFactory::empty<int>();
auto y = NDArrayFactory::create<int>(1);
nd4j::ops::fill op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(y, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
auto a = NDArrayFactory::create<float>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
auto b = NDArrayFactory::create<float>('c', {1, 3});
auto c = NDArrayFactory::create<float>('c', {1, 3});
auto d = NDArrayFactory::create<float>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
auto e = NDArrayFactory::create<float>('c', {3});
auto f = NDArrayFactory::create<float>('c', {3});
auto g = NDArrayFactory::create<float>('c', {3});
auto h = NDArrayFactory::create<float>('c', {12});
auto z0 = NDArrayFactory::create<float>('c', {1, 3});
auto z1 = NDArrayFactory::create<float>('c', {1, 3});
auto z2 = NDArrayFactory::create<float>('c', {1, 3});
auto z3 = NDArrayFactory::create<float>('c', {1, 3});
auto z4 = NDArrayFactory::create<float>('c', {1, 3});
auto z5 = NDArrayFactory::create<float>('c', {1, 3});
auto z6 = NDArrayFactory::create<float>('c', {1, 3});
nd4j::ops::lstmBlockCell op;
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
ASSERT_EQ(Status::OK(), result);
}
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
auto x = NDArrayFactory::create<float>('c', {0});
auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::stack op;
auto result = op.execute({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
nd4j::ops::reduce_min sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
out->printShapeInfo("ReduceSum empty shape with keep dims");
out->printIndexedBuffer("ReduceSum scalar");
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
delete res2;
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {0});
nd4j::ops::stack op;
auto result = op.execute({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {2, 0});
nd4j::ops::stack op;
auto result = op.execute({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
auto x = NDArrayFactory::create<float>('c', {0});
auto e = NDArrayFactory::create<float>('c', {2, 0});
nd4j::ops::stack op;
auto result = op.execute({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_min sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
delete res2;
}
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_max sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
delete res2;
}
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_sum sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), 0.f);
delete res2;
}
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0});
nd4j::ops::reduce_mean sumOp;
auto res2 = sumOp.execute({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
out->printShapeInfo("ReduceMean empty shape with keep dims");
out->printIndexedBuffer("ReduceMean scalar");
ASSERT_EQ(out->e<float>(0), 0.f);
delete res2;
}
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
auto x = NDArrayFactory::create<float>('c', {1, 0});
auto y = NDArrayFactory::create<int>(0);
auto e = NDArrayFactory::create<Nd4jLong>('c', {0});
nd4j::ops::argmax op;
//nd4j::ops::reduce_max op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Z");
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
auto x = NDArrayFactory::create<float>('c', {1, 0});
auto y = NDArrayFactory::create<int>(1);
nd4j::ops::argmax op;
try {
auto result = op.execute({&x, &y}, {&y}, {}, {}, {});
ASSERT_TRUE(false);
} catch (std::exception &e) {
//
}
}
TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
auto x = NDArrayFactory::create<float>('c', {32, 0});
nd4j::ops::tanh op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(x.isSameShape(z));
ASSERT_EQ(x, *z);
delete result;
}

View File

@ -905,9 +905,9 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
auto x = NDArrayFactory::create<double>('c', {1}, {10});
auto begin = NDArrayFactory::create<double>('c', {1}, {0.});
auto end = NDArrayFactory::create<double>('c', {1}, {0.});
auto stride = NDArrayFactory::create<double>('c', {1}, {1});
auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0});
auto end = NDArrayFactory::create<int>('c', {1}, {(int)0});
auto stride = NDArrayFactory::create<int>('c', {1}, {1});
//x.linspace(1);
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
//exp.linspace(1);

View File

@ -452,14 +452,14 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3_1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test1) {
auto expected = NDArrayFactory::create<float>('c', {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1});
nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3});
auto output = results->at(0);
// output->printIndexedBuffer();
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
@ -469,7 +469,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test2) {
auto expected = NDArrayFactory::create<float>('c', {3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
nd4j::ops::eye op;
@ -485,14 +485,14 @@ TEST_F(DeclarableOpsTests5, eye_test2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test3) {
auto expected = NDArrayFactory::create<float>('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0});
nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2});
auto output = results->at(0);
output->printIndexedBuffer("Output eye");
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
@ -502,13 +502,13 @@ TEST_F(DeclarableOpsTests5, eye_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, eye_test4) {
auto expected = NDArrayFactory::create<float>('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.});
nd4j::ops::eye op;
auto results = op.execute({}, {}, {-99, 3, 4, 2, 2});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
@ -633,7 +633,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test6) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<int>('c', {4}, {4,4,4,4});
@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {2, 1});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -652,7 +652,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test2) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<Nd4jLong>('c', {4}, {0,1,2,3});
@ -661,7 +661,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {2, 1});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -671,7 +671,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test3) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<int>('c', {3}, {2,3,4});
@ -680,7 +680,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {2, 0});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -690,7 +690,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test4) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<int>('c', {5}, {1, 2, 1, 2, 3});
@ -699,7 +699,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {0, 2});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -709,7 +709,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test5) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<int>('c', {5}, {1, 2, 4, 2, 3});
@ -718,7 +718,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {1, 2});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -728,7 +728,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test6) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto seqLengths = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 2});
@ -737,7 +737,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -747,7 +747,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test7) {
auto input = NDArrayFactory::create<double>('c', {1, 5});
input.linspace(1);
std::vector<int> data = {3};
@ -757,7 +757,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {1, 0});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -767,7 +767,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test8) {
auto input = NDArrayFactory::create<double>('c', {1, 5});
input.linspace(1);
std::vector<int> data = {1,0,1,0,1};
@ -777,7 +777,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -787,7 +787,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test9) {
auto input = NDArrayFactory::create<double>('c', {5, 1});
input.linspace(1);
std::vector<Nd4jLong> data = {1,0,1,0,1};
@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {1, 0});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -807,7 +807,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test10) {
auto input = NDArrayFactory::create<double>('c', {5, 1});
input.linspace(1);
std::vector<int> data = {3};
@ -817,7 +817,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {0, 1});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -827,7 +827,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test11) {
auto input = NDArrayFactory::create<double>('c', {1, 1, 5, 1});
input.linspace(1);
std::vector<int> data = {1, 0, 1, 0, 1};
@ -837,7 +837,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {1, 2});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -847,7 +847,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test12) {
auto input = NDArrayFactory::create<double>('c', {1, 1, 5, 1});
input.linspace(1);
std::vector<int> data = {3};
@ -857,7 +857,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {2, 0});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -867,7 +867,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, reverse_sequense_test13) {
auto input = NDArrayFactory::create<double>('c', {1, 1, 5, 1});
input.linspace(1);
std::vector<int> data = {1};
@ -877,7 +877,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test13) {
nd4j::ops::reverse_sequence op;
auto results = op.execute({&input, &seqLengths}, {}, {3, 0});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1307,9 +1307,9 @@ TEST_F(DeclarableOpsTests5, Test_Moments_3) {
11.0, 13.0, 14.0, 5.0,
16.0, 9.0, 13.5, 7.0}
);
auto expV = NDArrayFactory::create<double>('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f,
8.5f, 11.f, 8.75f, 6.f,
8.5f, 11.f, 8.75f, 6.f,
18.5f, 6.f, 13.75f, 11.f});
auto expD = NDArrayFactory::create<double>('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f,
6.25f, 4.f, 27.5625f, 1.f,
@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, trace_test1) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5});
input.linspace(1);
auto exp = NDArrayFactory::create<double>('c', {3}, {40, 120, 200});
@ -1389,7 +1389,7 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, trace_test2) {
auto input = NDArrayFactory::create<double>('c', {4, 5});
input.linspace(1);
auto exp = NDArrayFactory::create<double>(40.);
@ -1397,7 +1397,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
nd4j::ops::trace op;
auto results = op.execute({&input}, {}, {});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1407,7 +1407,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, trace_test3) {
auto input = NDArrayFactory::create<double>('c', {1, 5});
input.linspace(1);
auto exp = NDArrayFactory::create<double>(1.);
@ -1415,7 +1415,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
nd4j::ops::trace op;
auto results = op.execute({&input}, {}, {});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1425,7 +1425,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, trace_test4) {
auto input = NDArrayFactory::create<double>('c', {5, 1});
input.linspace(1);
auto exp = NDArrayFactory::create<double>(1.);
@ -1433,7 +1433,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
nd4j::ops::trace op;
auto results = op.execute({&input}, {}, {});
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, trace_test5) {
auto input = NDArrayFactory::create<double>('c', {3, 4, 5, 6});
input.linspace(1);
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725});
@ -1451,7 +1451,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
nd4j::ops::trace op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1461,7 +1461,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
input.linspace(1);
@ -1473,7 +1473,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
@ -1484,9 +1484,9 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test2) {
auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
input.linspace(1);
input.linspace(1);
nd4j::ops::random_shuffle op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
@ -1494,14 +1494,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(input.equalsTo(output));
ASSERT_TRUE(input.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
input.linspace(1);
@ -1513,7 +1513,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
@ -1535,7 +1535,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
@ -1546,7 +1546,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
auto input = NDArrayFactory::create<double>('c', {4,1});
input.linspace(1);
@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
@ -1569,7 +1569,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
auto input = NDArrayFactory::create<double>('c', {4,1,1});
input.linspace(1);
@ -1581,7 +1581,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
@ -1592,7 +1592,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
auto input = NDArrayFactory::create<double>('c', {1,4});
input.linspace(1);
auto exp = NDArrayFactory::create<double>('c', {1,4}, {1, 2, 3, 4});
@ -1611,11 +1611,11 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
////////////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) {
auto x = NDArrayFactory::create<double>('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23,
14, 24, 15, 25, 16, 26, 17, 27,
18, 28, 19, 29, 20, 30, 21, 31});
auto y = NDArrayFactory::create<int>({1, 1, 1, 0, 0, 0, 2, 2, 2});
auto exp = NDArrayFactory::create<double>('c', {9, 4, 2}, {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25,
16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27,
@ -1637,17 +1637,17 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) {
ASSERT_TRUE(exp.isSameShape(output));
//output->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) {
auto x = NDArrayFactory::create<double>('c', {3, 4, 2}, {10, 20, 30, 40, 50, 60,
70, 80, 90, 10, 11, 12,
13, 14, 15, 16, 17, 18,
70, 80, 90, 10, 11, 12,
13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24});
//1, 0, 1, 0, 1, 0
auto y = NDArrayFactory::create<Nd4jLong>({1, 0, 1, 0, 1, 0});
@ -1673,7 +1673,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) {
ASSERT_TRUE(exp.isSameShape(output));
// output->printIndexedBuffer("Output");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(output));
delete result;
@ -1721,19 +1721,19 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) {
}
TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
auto x = NDArrayFactory::create<double>('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22,
13, 23, 14, 24, 15, 25, 16, 26, 17, 27,
18, 28, 19, 29, 20, 30, 21, 31});
auto y = NDArrayFactory::create<double>('c', {3, 4, 2}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f
2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f
}
);
/* auto y = NDArrayFactory::create<double>('c', {3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f
2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f,
1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f
}
);
*/
@ -1762,7 +1762,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, DynamicPartition_2) {
auto x = NDArrayFactory::create<double>('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f});
auto y = NDArrayFactory::create<double>('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0});
@ -1794,7 +1794,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_2) {
TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
auto x = NDArrayFactory::create<double>('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f});
auto y = NDArrayFactory::create<double>('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0});
@ -1817,7 +1817,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
// output->printShapeInfo("Output shape> ");
// exp[e].printShapeInfo("Expected shape> ");
// output->printIndexedBuffer("Output data> ");
ASSERT_TRUE(exp[e].isSameShape(output));
ASSERT_TRUE(exp[e].equalsTo(output));
}
@ -1833,13 +1833,13 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
auto x1 = NDArrayFactory::create<double>({1., 3., 5., 0.});
auto x2 = NDArrayFactory::create<double>({2., 4.});
auto y2 = NDArrayFactory::create<double>({-1., -1.});
auto y1 = NDArrayFactory::create<double>({0.1f, 5.2f, 4.3f, 7.4f});
auto exp = NDArrayFactory::create<double>({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f});
nd4j::ops::dynamic_stitch op;
@ -1852,7 +1852,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
// output->printShapeInfo("Output shape> ");
// exp.printShapeInfo("Expected shape> ");
// output->printIndexedBuffer("Output data> ");
// exp.printIndexedBuffer("Expected res>");
// exp.printIndexedBuffer("Expected res>");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1862,13 +1862,13 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, DynamicStitch_2) {
auto x1 = NDArrayFactory::create<double>({1.f, 3.f});
auto x2 = NDArrayFactory::create<double>({5.f, 0.f, 2.f, 4.f});
auto y1 = NDArrayFactory::create<double>({-1.f, -1.f});
auto y2 = NDArrayFactory::create<double>({0.1f, 5.2f, 4.3f, 7.4f});
auto exp = NDArrayFactory::create<double>({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f});
nd4j::ops::dynamic_stitch op;
@ -1881,7 +1881,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) {
// output->printShapeInfo("Output shape> ");
// exp.printShapeInfo("Expected shape> ");
// output->printIndexedBuffer("Output data> ");
// exp.printIndexedBuffer("Expected res>");
// exp.printIndexedBuffer("Expected res>");
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -1890,11 +1890,11 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 4});
x.linspace(1);
auto scale = NDArrayFactory::create<double>('c', {4});
scale = 0.5;
auto offset = NDArrayFactory::create<double>('c', {4});
offset = 2.;
@ -1908,7 +1908,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) {
auto y = results->at(0);
auto batchMean = results->at(1);
auto batchVar = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expY.isSameShape(y));
ASSERT_TRUE(expBatchMean.isSameShape(batchMean));
@ -1919,12 +1919,12 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 4});
x.linspace(1);
auto scale = NDArrayFactory::create<double>('c', {4});
scale = 0.5;
auto offset = NDArrayFactory::create<double>('c', {4});
offset = 2.;
@ -1937,7 +1937,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) {
auto y = results->at(0);
auto batchMean = results->at(1);
auto batchVar = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expY.isSameShape(y));
ASSERT_TRUE(expBatchMean.isSameShape(batchMean));
@ -1948,12 +1948,12 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) {
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 3});
x.linspace(1);
auto scale = NDArrayFactory::create<double>('c', {4});
scale = 0.5;
auto offset = NDArrayFactory::create<double>('c', {4});
offset = 2.;
@ -1966,7 +1966,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) {
auto y = results->at(0);
auto batchMean = results->at(1);
auto batchVar = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expY.isSameShape(y));
ASSERT_TRUE(expBatchMean.isSameShape(batchMean));
@ -1977,7 +1977,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 4});
x.linspace(1);
std::vector<Nd4jLong> shape = {4};
@ -1985,8 +1985,8 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) {
auto offset = NDArrayFactory::create<double>('c', shape);
auto mean = NDArrayFactory::create<double>('c', shape);
auto variance = NDArrayFactory::create<double>('c', shape);
scale = 0.5;
scale = 0.5;
offset = 2.;
mean = 25.;
variance = 5.;
@ -2001,7 +2001,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) {
auto y = results->at(0);
auto batchMean = results->at(1);
auto batchVar = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expY.isSameShape(y));
ASSERT_TRUE(expBatchMean.isSameShape(batchMean));
@ -2012,7 +2012,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 3, 4});
x.linspace(1);
std::vector<Nd4jLong> shape = {4};
@ -2020,8 +2020,8 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) {
auto offset = NDArrayFactory::create<double>('c', shape);
auto mean = NDArrayFactory::create<double>('c', shape);
auto variance = NDArrayFactory::create<double>('c', shape);
scale = 0.5;
scale = 0.5;
offset = 2.;
mean = 25.;
variance = 5.;
@ -2036,7 +2036,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) {
auto y = results->at(0);
auto batchMean = results->at(1);
auto batchVar = results->at(2);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expY.isSameShape(y));
ASSERT_TRUE(expBatchMean.isSameShape(batchMean));
@ -2131,49 +2131,49 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) {
///////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, ZeroFraction_1) {
auto x = NDArrayFactory::create<double>('c', {3, 4, 2}, {0, 20, 30, 0, 50, 0,
70, 0, 90, 0, 11, 12,
13, 14, 15, 16, 17, 18,
70, 0, 90, 0, 11, 12,
13, 14, 15, 16, 17, 18,
19, 0, 21, 22, 23, 24});
nd4j::ops::zero_fraction op;
auto res = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), res->status());
ASSERT_TRUE(res->at(0)->isScalar());
ASSERT_EQ(res->at(0)->e<double>(0), 0.25);
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, ZeroFraction_2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4});
nd4j::ops::zero_fraction op;
auto res = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), res->status());
ASSERT_TRUE(res->at(0)->isScalar());
ASSERT_EQ(res->at(0)->e<double>(0), 0.375);
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, ZeroFraction_3) {
auto x = NDArrayFactory::create<double>('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4});
nd4j::ops::zero_fraction op;
auto res = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), res->status());
ASSERT_TRUE(res->at(0)->isScalar());
ASSERT_EQ(res->at(0)->e<double>(0), 0.375);
delete res;
}
@ -2219,7 +2219,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_1) {
// output->printShapeInfo("Output shape> ");
// x.printShapeInfo("Expected shape> ");
// output->printIndexedBuffer("Output data> ");
// x.printIndexedBuffer("Expected res>");
// x.printIndexedBuffer("Expected res>");
ASSERT_TRUE(x.isSameShape(output));
ASSERT_TRUE(x.equalsTo(output));
@ -2242,7 +2242,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_2) {
// output->printShapeInfo("Output shape> ");
// x.printShapeInfo("Expected shape> ");
// output->printIndexedBuffer("Output data> ");
// x.printIndexedBuffer("Expected res>");
// x.printIndexedBuffer("Expected res>");
ASSERT_TRUE(x.isSameShape(output));
ASSERT_TRUE(x.equalsTo(output));
@ -2262,7 +2262,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test1) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2279,7 +2279,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test2) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test3) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
@ -2314,7 +2314,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test5) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2331,7 +2331,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test6) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2348,7 +2348,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test7) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2365,7 +2365,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test8) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2382,7 +2382,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test9) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2399,7 +2399,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test10) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
@ -2416,25 +2416,45 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) {
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z));
ASSERT_TRUE(expOutput.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, log_softmax_test12) {
auto input = NDArrayFactory::create<double>('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864});
auto expOutput = NDArrayFactory::create<double>('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472});
for (int i = 0; i < 10; ++i)
{
nd4j::ops::log_softmax op;
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(expOutput.isSameShape(z));
ASSERT_TRUE(expOutput.equalsTo(z, 1e-4));
delete results;
}
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 2}, {1,2,3,4});
auto epsilon = NDArrayFactory::create<double>('c', {2, 2}, {0.1, 0.2, 0.3, 0.4});
auto exp = NDArrayFactory::create<double>('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689});
nd4j::ops::log_softmax_bp op;
auto results = op.execute({&input, &epsilon}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
@ -2445,14 +2465,14 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) {
auto input = NDArrayFactory::create<double>('c', {2, 2}, {1,2,3,4});
auto epsilon = NDArrayFactory::create<double>('c', {2, 2}, {0.1, 0.2, 0.3, 0.4});
auto exp = NDArrayFactory::create<double>('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384});
nd4j::ops::log_softmax_bp op;
auto results = op.execute({&input, &epsilon}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
@ -2463,7 +2483,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.});
auto res = NDArrayFactory::create<double>('c', {2, 2, 2});
input.applyTransform(transform::ELU, &res);
ASSERT_TRUE(res.equalsTo(&exp));
@ -2474,7 +2494,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_1) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.});
double exp(9.605);
nd4j::ops::l2_loss op;
auto results = op.execute({&input}, {}, {});
auto output = results->at(0);
@ -2522,14 +2542,14 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) {
auto targets = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817});
nd4j::ops::log_poisson_loss op;
auto results = op.execute({&input, &weights, &targets}, {}, {0}, {}, false, nd4j::DataType::DOUBLE);
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
@ -2543,14 +2563,14 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) {
auto targets = NDArrayFactory::create<double>('c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0});
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882});
nd4j::ops::log_poisson_loss op;
auto results = op.execute({&input, &weights, &targets}, {}, {0, 1}, {}, false, nd4j::DataType::DOUBLE);
auto output = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
ASSERT_TRUE(exp.equalsTo(output));
delete results;
}
@ -2600,9 +2620,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_1) {
auto outputDeviance = results->at(1);
ASSERT_TRUE(expMeans.isSameShape(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.isSameShape(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
delete results;
}
@ -2651,9 +2671,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_2) {
auto outputDeviance = results->at(1);
ASSERT_TRUE(expMeans.isSameShape(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.isSameShape(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
delete results;
}
@ -2702,9 +2722,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_3) {
auto outputDeviance = results->at(1);
ASSERT_TRUE(expMeans.isSameShape(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.equalsTo(outputMeans));
ASSERT_TRUE(expMeans.isSameShape(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance));
delete results;
}

View File

@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
auto exp = NDArrayFactory::create<double>(10);
//matrix.linspace(1);
@ -119,7 +119,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto z = result->at(0);
z->printShapeInfo("SS OS shape");
ASSERT_TRUE(z->isEmpty());
z->printIndexedBuffer("SS OS out");
ASSERT_TRUE(z->equalsTo(exp));
//ASSERT_EQ(exp, *z);
delete result;

View File

@ -52,7 +52,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
nd4j::ops::reduce_stdev_bp op;
auto result = op.execute({&x, &gradO2}, {0,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
// output->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(output));
@ -60,7 +60,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) {
delete result;
result = op.execute({&x, &gradO1}, {1,0}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(ND4J_STATUS_OK, result->status());
output = result->at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -100,10 +100,10 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
const int N = 50000;
const double lambda = 2.;
const double mean = 1. / lambda;
const double mean = 1. / lambda;
const double std = mean;
auto x = NDArrayFactory::create<double>('c', {N});
@ -114,25 +114,25 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistributionInv<double>>(rng, x.getBuffer(), x.getShapeInfo(), extraParams);
const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng);
delete[] buffer;
}
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
const int N = 50000;
const double lambda = 2.;
const double mean = 1. / lambda;
const double mean = 1. / lambda;
const double std = mean;
double extraParams[] = {lambda};
@ -146,14 +146,14 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistributionInv<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng);
delete[] buffer;
@ -162,10 +162,10 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
const int N = 50000;
const double lambda = 2.;
const double mean = 1. / lambda;
const double mean = 1. / lambda;
const double std = mean;
auto x = NDArrayFactory::create<double>('c', {N});
@ -176,25 +176,25 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, x.getBuffer(), x.getShapeInfo(), extraParams);
const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng);
delete[] buffer;
delete[] buffer;
}
*/
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
const int N = 50000;
const double lambda = 2.;
const double mean = 1. / lambda;
const double mean = 1. / lambda;
const double std = mean;
double extraParams[] = {lambda};
@ -210,7 +210,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
nativeOps.destroyRandom((Nd4jPointer) rng);
@ -218,7 +218,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
ASSERT_NEAR(std, actualStd, 0.01);
@ -608,6 +608,24 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test16) {
auto x = NDArrayFactory::create<double>('c', {0,2,3});
auto y = NDArrayFactory::create<double>('c', {0,2,3});
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
nd4j::ops::concat op;
auto result = op.execute({&x, &y}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
@ -739,11 +757,11 @@ TEST_F(DeclarableOpsTests9, matmul_test1) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -761,11 +779,11 @@ TEST_F(DeclarableOpsTests9, matmul_test2) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -782,11 +800,11 @@ TEST_F(DeclarableOpsTests9, matmul_test3) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -804,11 +822,11 @@ TEST_F(DeclarableOpsTests9, matmul_test4) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -826,11 +844,11 @@ TEST_F(DeclarableOpsTests9, matmul_test5) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -847,11 +865,11 @@ TEST_F(DeclarableOpsTests9, matmul_test6) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -870,11 +888,11 @@ TEST_F(DeclarableOpsTests9, matmul_test7) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -895,18 +913,18 @@ TEST_F(DeclarableOpsTests9, matmul_test8) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, matmul_test9) {
@ -920,11 +938,11 @@ TEST_F(DeclarableOpsTests9, matmul_test9) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1133,11 +1151,11 @@ TEST_F(DeclarableOpsTests9, matmul_test10) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1174,7 +1192,7 @@ TEST_F(DeclarableOpsTests9, matmul_test12) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
@ -1195,11 +1213,11 @@ TEST_F(DeclarableOpsTests9, matmul_test13) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1216,11 +1234,11 @@ TEST_F(DeclarableOpsTests9, matmul_test14) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1237,11 +1255,11 @@ TEST_F(DeclarableOpsTests9, matmul_test15) {
x.linspace(1.);
y.linspace(0.5, 0.5);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1261,11 +1279,11 @@ TEST_F(DeclarableOpsTests9, matmul_test16) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1282,11 +1300,11 @@ TEST_F(DeclarableOpsTests9, matmul_test17) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 0});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1303,11 +1321,11 @@ TEST_F(DeclarableOpsTests9, matmul_test18) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {0, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1324,11 +1342,11 @@ TEST_F(DeclarableOpsTests9, matmul_test19) {
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1346,11 +1364,11 @@ TEST_F(DeclarableOpsTests9, matmul_test20) {
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1,1,1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1368,11 +1386,11 @@ TEST_F(DeclarableOpsTests9, matmul_test21) {
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1390,11 +1408,11 @@ TEST_F(DeclarableOpsTests9, matmul_test22) {
x.linspace(2.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1412,11 +1430,11 @@ TEST_F(DeclarableOpsTests9, matmul_test23) {
x.linspace(1.);
y.linspace(0.1, 0.1);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1431,11 +1449,11 @@ TEST_F(DeclarableOpsTests9, matmul_test24) {
auto x = NDArrayFactory::create<double>('f', {1}, {2.});
auto y = NDArrayFactory::create<double>('c', {1}, {3.});
auto exp = NDArrayFactory::create<double>(6.);
nd4j::ops::matmul op;
auto results = op.execute({&x, &y}, {}, {1, 1});
auto z = results->at(0);
ASSERT_EQ(Status::OK(), results->status());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@ -1534,34 +1552,34 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
const int bS = 5;
const int nOut = 4;
const int axis = 0;
const double clip = 2.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
auto expect = NDArrayFactory::create<double>('c', {bS, nOut});
auto norm2 = x.reduceAlongDims(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]
auto y = ( (x / norm2) * clip) * colVect ;
auto temp = (x / norm2) * clip;
for (int j = 0; j < nOut; ++j) {
auto yCol = y({0,0, j,j+1});
const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
if (norm2Col <= clip)
if (norm2Col <= clip)
expect({0,0, j,j+1}).assign(yCol);
else
else
expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
}
nd4j::ops::clipbynorm op;
auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE);
auto outFF = result->at(0);
auto outFF = result->at(0);
ASSERT_TRUE(expect.isSameShape(outFF));
ASSERT_TRUE(expect.equalsTo(outFF));
@ -1571,12 +1589,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {
const int bS = 2;
const int nOut = 3;
const int axis = 0;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
@ -1593,12 +1611,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {
const int bS = 2;
const int nOut = 3;
const int axis = 0;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
@ -1616,12 +1634,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) {
const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
@ -1734,7 +1752,7 @@ TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, cumprod_test1) {
auto inputC = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
auto axis = NDArrayFactory::create<double>(1.);
@ -1745,7 +1763,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
auto gradO = NDArrayFactory::create<double>('c', {3, 5});
int exclusive, reverse;
int exclusive, reverse;
//************************************//
exclusive = 0; reverse = 0;
@ -1764,8 +1782,8 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
/* exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z));
delete result;
*/
@ -1773,8 +1791,8 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
/* exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z));
delete result;
*/
@ -1782,16 +1800,16 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
/* exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z));
delete result;
*/
delete result;
*/
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, cumprod_test2) {
auto inputC = NDArrayFactory::create<double>('c', {2, 2});
auto axis = NDArrayFactory::create<double>(1.);
@ -1802,7 +1820,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
// auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
auto gradO = NDArrayFactory::create<double>('c', {2, 2});
int exclusive, reverse;
int exclusive, reverse;
//************************************//
exclusive = 0; reverse = 0;
@ -1820,7 +1838,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {3, 4}, {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f,-0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1839,7 +1857,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test2) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {3}, {-0.6f, 2.f, 4.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1857,7 +1875,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test3) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {3,1}, {-0.6f, 2.f, 4.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1875,7 +1893,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test4) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {1, 3}, {-0.6f, 2.f, 4.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1893,7 +1911,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test5) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {4}, {-0.6f, 2.f, 4.f, -1.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1911,7 +1929,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test6) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>('c', {1,1,1}, {-2.});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1930,7 +1948,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test7) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>(-2.f);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1948,7 +1966,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test8) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto alpha = NDArrayFactory::create<double>(-2.f);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
@ -1966,7 +1984,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test9) {
auto x = NDArrayFactory::create<double>('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f});
auto alpha = NDArrayFactory::create<double>(-2.f);
auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
@ -1984,7 +2002,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test10) {
auto x = NDArrayFactory::create<double>('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f});
auto alpha = NDArrayFactory::create<double>(-2.f);
auto exp = NDArrayFactory::create<double>('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f});
@ -2002,16 +2020,16 @@ TEST_F(DeclarableOpsTests9, prelu_test10) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test11) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
x.linspace(-50.);
auto alpha = NDArrayFactory::create<double>('c', {4}, {0.f, -0.5f, 0.5f, -1.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4, 5}, {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f,
32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f,
14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f,
24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f,
43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f,
32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f,
14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f,
24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f,
43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f,
62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op;
@ -2027,15 +2045,15 @@ TEST_F(DeclarableOpsTests9, prelu_test11) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test12) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
x.linspace(-50.);
auto alpha = NDArrayFactory::create<double>('c', {3,5}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f,
9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f,
-2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f,
31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,
9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f,
-2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f,
31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,
53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op;
@ -2051,15 +2069,15 @@ TEST_F(DeclarableOpsTests9, prelu_test12) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test13) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
x.linspace(-50.);
auto alpha = NDArrayFactory::create<double>('c', {5,3}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f,
9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f,
-2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f,
31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,
9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f,
-2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f,
31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f,
53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op;
@ -2075,16 +2093,16 @@ TEST_F(DeclarableOpsTests9, prelu_test13) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_test14) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
x.linspace(-50.);
auto alpha = NDArrayFactory::create<double>('c', {2,10}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f,
-33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f,
-11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f,
1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f,
-33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f,
-11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f,
1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f,
37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f,
55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f});
nd4j::ops::prelu op;
@ -2100,7 +2118,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
const float theta = 2.f;
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 3.f,4.f, 5.f, 6.f, 7.f,8.f, 9.f,10.f,11.f});
@ -2116,10 +2134,10 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
const float theta = -2.f;
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.f,-4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f});
@ -2128,7 +2146,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
auto output = result->at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -2138,7 +2156,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.});
auto alpha = NDArrayFactory::create<double>('c', {3, 4}, {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5});
auto dLdO = NDArrayFactory::create<double>('c', {2, 3, 4});
@ -2156,7 +2174,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test2) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.});
auto alpha = NDArrayFactory::create<double>('c', {4}, {-0.6, 2., 4., -1.});
auto dLdO = NDArrayFactory::create<double>('c', {2, 3, 4});
@ -2174,7 +2192,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test3) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 2, 5});
x.linspace(-30.);
x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu
@ -2194,7 +2212,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, prelu_bp_test4) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4, 5});
x.linspace(-50.);
x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele
@ -2214,7 +2232,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) {
const double theta = 0.15;
auto x = NDArrayFactory::create<double>('c', {2, 3, 4}, {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6,-0.5,-0.4,-0.3,-0.2,-0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1});
@ -2233,7 +2251,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
auto y = NDArrayFactory::create<double>('c', {4});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f});
@ -2253,7 +2271,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test2) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
auto y = NDArrayFactory::create<double>(0.1);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f});
@ -2273,7 +2291,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test3) {
auto x = NDArrayFactory::create<double>('c', {2, 1, 4});
auto y = NDArrayFactory::create<double>('c', {3,1});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f});
@ -2293,11 +2311,11 @@ TEST_F(DeclarableOpsTests9, multiply_test3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test4) {
auto x = NDArrayFactory::create<double>('c', {1, 1});
auto y = NDArrayFactory::create<double>(0.1f);
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {0.1f});
x.linspace(1.f);
x.linspace(1.f);
nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {});
@ -2312,11 +2330,11 @@ TEST_F(DeclarableOpsTests9, multiply_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_test5) {
auto x = NDArrayFactory::create<double>(1.f);
auto y = NDArrayFactory::create<double>(0.1f);
auto exp = NDArrayFactory::create<double>(0.1f);
nd4j::ops::multiply op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
@ -2330,7 +2348,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {1, 1}, {100.});
auto y = NDArrayFactory::create<double>(0.1);
auto dLdz = NDArrayFactory::create<double>('c', {1, 1});
@ -2353,7 +2371,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test2) {
auto x = NDArrayFactory::create<double>('c', {2, 2}, {1.,2.,3.,4.});
auto y = NDArrayFactory::create<double>(0.1);
auto dLdz = NDArrayFactory::create<double>('c', {2, 2});
@ -2371,7 +2389,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test3) {
auto y = NDArrayFactory::create<double>('c', {2, 2}, {1.,2.,3.,4.});
auto x = NDArrayFactory::create<double>(0.1);
auto dLdz = NDArrayFactory::create<double>('c', {2, 2});
@ -2389,7 +2407,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test3) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test4) {
auto x = NDArrayFactory::create<double>('c', {2, 2}, {1.,2.,3.,4.});
auto y = NDArrayFactory::create<double>('c', {2, 2}, {0.1,0.2,0.3,0.4});
auto dLdz = NDArrayFactory::create<double>('c', {2, 2});
@ -2407,7 +2425,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test5) {
auto x = NDArrayFactory::create<double>('c', {2, 2}, {1.,2.,3.,4.});
auto y = NDArrayFactory::create<double>('c', {2}, {0.1,0.2});
auto dLdz = NDArrayFactory::create<double>('c', {2, 2});
@ -2425,7 +2443,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test5) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test6) {
auto y = NDArrayFactory::create<double>('c', {2, 2}, {1.,2.,3.,4.});
auto x = NDArrayFactory::create<double>('c', {2}, {0.1,0.2});
auto dLdz = NDArrayFactory::create<double>('c', {2, 2});
@ -2443,7 +2461,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test6) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test7) {
auto y = NDArrayFactory::create<double>('c', {2, 3}, {1.,2.,3.,4.,5.,6.});
auto x = NDArrayFactory::create<double>('c', {2, 1}, {0.1,0.2});
auto dLdz = NDArrayFactory::create<double>('c', {2, 3});
@ -2461,7 +2479,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test7) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, multiply_bp_test8) {
auto y = NDArrayFactory::create<double>('c', {2, 1, 4});
auto x = NDArrayFactory::create<double>('c', {1, 3, 4});
auto dLdz = NDArrayFactory::create<double>('c', {2, 3, 4});

View File

@ -59,7 +59,8 @@ TEST_F(EmptyTests, Test_Create_Empty_2) {
}
TEST_F(EmptyTests, Test_Concat_1) {
auto empty = NDArrayFactory::empty_<float>();
// auto empty = NDArrayFactory::empty_<float>();
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32);//NDArrayFactory::create_<float>('c', {(Nd4jLong)0}};
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
ASSERT_TRUE(empty->isEmpty());
@ -82,9 +83,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
TEST_F(EmptyTests, Test_Concat_2) {
auto empty = NDArrayFactory::empty_<float>();
auto scalar1 = NDArrayFactory::create_<float>(1.0f);
auto scalar2 = NDArrayFactory::create_<float>(2.0f);
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32); //NDArrayFactory::empty_<float>();
auto scalar1 = NDArrayFactory::create_<float>('c', {1}, {1.0f});
auto scalar2 = NDArrayFactory::create_<float>('c', {1}, {2.0f});
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
ASSERT_TRUE(empty->isEmpty());
@ -139,6 +140,23 @@ TEST_F(EmptyTests, Test_Reshape_2) {
delete result;
}
TEST_F(EmptyTests, Test_Reshape_3) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
auto e = NDArrayFactory::create<float>('c', {10, 0});
nd4j::ops::reshape op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(e.isSameShape(z));
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(EmptyTests, Test_dup_1) {
auto empty = NDArrayFactory::empty<int>();
auto dup = empty.dup();
@ -147,4 +165,48 @@ TEST_F(EmptyTests, Test_dup_1) {
ASSERT_EQ(empty, *dup);
delete dup;
}
TEST_F(EmptyTests, test_shaped_empty_1) {
auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
std::vector<Nd4jLong> shape = {2, 0, 3};
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
ASSERT_EQ(0, empty.lengthOf());
ASSERT_TRUE(empty.isEmpty());
ASSERT_EQ(shape, empty.getShapeAsVector());
ASSERT_EQ(3, empty.rankOf());
}
TEST_F(EmptyTests, test_shaped_empty_2) {
auto empty = NDArrayFactory::create<float>('c', {0, 3});
std::vector<Nd4jLong> shape = {0, 3};
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
ASSERT_EQ(0, empty.lengthOf());
ASSERT_TRUE(empty.isEmpty());
ASSERT_EQ(shape, empty.getShapeAsVector());
ASSERT_EQ(2, empty.rankOf());
}
TEST_F(EmptyTests, test_shaped_empty_3) {
auto empty = NDArrayFactory::create<float>('c', {0});
std::vector<Nd4jLong> shape = {0};
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
ASSERT_EQ(0, empty.lengthOf());
ASSERT_TRUE(empty.isEmpty());
ASSERT_EQ(shape, empty.getShapeAsVector());
ASSERT_EQ(1, empty.rankOf());
}
TEST_F(EmptyTests, test_shaped_empty_4) {
auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::FLOAT32);
shape::printShapeInfoLinear("shape", shape);
NDArray array(shape, true, nd4j::LaunchContext::defaultContext());
std::vector<Nd4jLong> shapeOf({0});
ASSERT_TRUE(array.isEmpty());
ASSERT_EQ(1, array.rankOf());
ASSERT_EQ(shapeOf, array.getShapeAsVector());
}

Some files were not shown because too many files have changed in this diff Show More