Dev branch merge: dev_20190606 (#7904)
* correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. * couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. * broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaksmaster
parent
32e5cc1945
commit
68ea5f3688
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
|
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||||
|
@ -283,7 +284,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
net.computeGradientAndScore();
|
net.computeGradientAndScore();
|
||||||
net2.computeGradientAndScore();
|
net2.computeGradientAndScore();
|
||||||
|
|
||||||
System.out.println(net.score() + "\t" + net2.score());
|
|
||||||
assertEquals(net2.score(), net.score(), 1e-6);
|
assertEquals(net2.score(), net.score(), 1e-6);
|
||||||
|
|
||||||
Map<String, INDArray> gradient = net.gradient().gradientForVariable();
|
Map<String, INDArray> gradient = net.gradient().gradientForVariable();
|
||||||
|
@ -441,85 +441,87 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
int numInputClasses = 10;
|
int numInputClasses = 10;
|
||||||
int timeSeriesLength = 5;
|
int timeSeriesLength = 5;
|
||||||
|
|
||||||
for (int nExamples : miniBatchSizes) {
|
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
for (int nExamples : miniBatchSizes) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd(0.1)).seed(12345).list()
|
.updater(new Sgd(0.1)).seed(12345).list()
|
||||||
.layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
|
.layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
|
||||||
.nOut(5).build())
|
.nOut(5).build())
|
||||||
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||||
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||||
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
|
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
|
||||||
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
|
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
net.init();
|
net.init();
|
||||||
|
|
||||||
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
.updater(new Sgd(0.1)).seed(12345).list()
|
.updater(new Sgd(0.1)).seed(12345).list()
|
||||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
|
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||||
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
.layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||||
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||||
.nOut(4).build())
|
.nOut(4).build())
|
||||||
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
|
.inputPreProcessor(0, new RnnToFeedForwardPreProcessor())
|
||||||
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
|
.inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build();
|
||||||
|
|
||||||
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
|
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
|
||||||
net2.init();
|
net2.init();
|
||||||
|
|
||||||
net2.setParams(net.params().dup());
|
net2.setParams(net.params().dup());
|
||||||
|
|
||||||
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
|
INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength);
|
||||||
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
|
INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength);
|
||||||
|
|
||||||
INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength);
|
INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength);
|
||||||
|
|
||||||
for (int i = 0; i < nExamples; i++) {
|
for (int i = 0; i < nExamples; i++) {
|
||||||
for (int j = 0; j < timeSeriesLength; j++) {
|
for (int j = 0; j < timeSeriesLength; j++) {
|
||||||
int inIdx = r.nextInt(numInputClasses);
|
int inIdx = r.nextInt(numInputClasses);
|
||||||
inEmbedding.putScalar(new int[]{i, 0, j}, inIdx);
|
inEmbedding.putScalar(new int[]{i, 0, j}, inIdx);
|
||||||
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
|
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
|
||||||
|
|
||||||
int outIdx = r.nextInt(4);
|
int outIdx = r.nextInt(4);
|
||||||
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
|
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
INDArray inputMask = Nd4j.zeros(nExamples, timeSeriesLength);
|
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
|
||||||
for (int i = 0; i < nExamples; i++) {
|
for (int i = 0; i < nExamples; i++) {
|
||||||
for (int j = 0; j < timeSeriesLength; j++) {
|
for (int j = 0; j < timeSeriesLength; j++) {
|
||||||
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
|
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
net.setLayerMaskArrays(inputMask, null);
|
net.setLayerMaskArrays(inputMask, null);
|
||||||
net2.setLayerMaskArrays(inputMask, null);
|
net2.setLayerMaskArrays(inputMask, null);
|
||||||
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
|
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
|
||||||
List<INDArray> actDense = net2.feedForward(inDense, false);
|
List<INDArray> actDense = net2.feedForward(inDense, false);
|
||||||
for (int i = 1; i < actEmbedding.size(); i++) {
|
for (int i = 1; i < actEmbedding.size(); i++) {
|
||||||
assertEquals(actDense.get(i), actEmbedding.get(i));
|
assertEquals(actDense.get(i), actEmbedding.get(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
net.setLabels(labels);
|
net.setLabels(labels);
|
||||||
net2.setLabels(labels);
|
net2.setLabels(labels);
|
||||||
net.computeGradientAndScore();
|
net.computeGradientAndScore();
|
||||||
net2.computeGradientAndScore();
|
net2.computeGradientAndScore();
|
||||||
|
|
||||||
System.out.println(net.score() + "\t" + net2.score());
|
System.out.println(net.score() + "\t" + net2.score());
|
||||||
assertEquals(net2.score(), net.score(), 1e-5);
|
assertEquals(net2.score(), net.score(), 1e-5);
|
||||||
|
|
||||||
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
|
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
|
||||||
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
|
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
|
||||||
assertEquals(gradients.keySet(), gradients2.keySet());
|
assertEquals(gradients.keySet(), gradients2.keySet());
|
||||||
for (String s : gradients.keySet()) {
|
for (String s : gradients.keySet()) {
|
||||||
assertEquals(gradients2.get(s), gradients.get(s));
|
assertEquals(gradients2.get(s), gradients.get(s));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -583,6 +585,104 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testEmbeddingSequenceLayerWithMasking() {
|
||||||
|
//Idea: have masking on the input with an embedding and dense layers on input
|
||||||
|
//Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked
|
||||||
|
|
||||||
|
int[] miniBatchSizes = {1, 3};
|
||||||
|
int nIn = 2;
|
||||||
|
Random r = new Random(12345);
|
||||||
|
|
||||||
|
int numInputClasses = 10;
|
||||||
|
int timeSeriesLength = 5;
|
||||||
|
|
||||||
|
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||||
|
for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||||
|
for(int inputRank : new int[]{2, 3}) {
|
||||||
|
for (int nExamples : miniBatchSizes) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
.updater(new Sgd(0.1)).seed(12345).list()
|
||||||
|
.layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
|
||||||
|
.nOut(5).build())
|
||||||
|
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||||
|
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||||
|
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||||
|
.nOut(4).build())
|
||||||
|
.setInputType(InputType.recurrent(1)).build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
||||||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||||
|
.updater(new Sgd(0.1)).seed(12345).list()
|
||||||
|
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
|
||||||
|
.build())
|
||||||
|
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||||
|
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||||
|
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||||
|
.nOut(4).build())
|
||||||
|
.setInputType(InputType.recurrent(1)).build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
|
||||||
|
net2.init();
|
||||||
|
|
||||||
|
net2.setParams(net.params().dup());
|
||||||
|
|
||||||
|
INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
|
||||||
|
INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
|
||||||
|
|
||||||
|
INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength);
|
||||||
|
|
||||||
|
for (int i = 0; i < nExamples; i++) {
|
||||||
|
for (int j = 0; j < timeSeriesLength; j++) {
|
||||||
|
int inIdx = r.nextInt(numInputClasses);
|
||||||
|
inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx);
|
||||||
|
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
|
||||||
|
|
||||||
|
int outIdx = r.nextInt(4);
|
||||||
|
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
|
||||||
|
for (int i = 0; i < nExamples; i++) {
|
||||||
|
for (int j = 0; j < timeSeriesLength; j++) {
|
||||||
|
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
net.setLayerMaskArrays(inputMask, null);
|
||||||
|
net2.setLayerMaskArrays(inputMask, null);
|
||||||
|
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
|
||||||
|
List<INDArray> actDense = net2.feedForward(inDense, false);
|
||||||
|
for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape)
|
||||||
|
assertEquals(actDense.get(i), actEmbedding.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
net.setLabels(labels);
|
||||||
|
net2.setLabels(labels);
|
||||||
|
net.computeGradientAndScore();
|
||||||
|
net2.computeGradientAndScore();
|
||||||
|
|
||||||
|
assertEquals(net2.score(), net.score(), 1e-5);
|
||||||
|
|
||||||
|
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
|
||||||
|
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
|
||||||
|
assertEquals(gradients.keySet(), gradients2.keySet());
|
||||||
|
for (String s : gradients.keySet()) {
|
||||||
|
assertEquals(gradients2.get(s), gradients.get(s));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
private static class WordVectorsMockup implements EmbeddingInitializer {
|
private static class WordVectorsMockup implements EmbeddingInitializer {
|
||||||
|
|
||||||
|
|
|
@ -213,6 +213,12 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
INDArray outLoaded = netLoaded.output(in);
|
INDArray outLoaded = netLoaded.output(in);
|
||||||
|
|
||||||
assertEquals(msg, outExp, outLoaded);
|
assertEquals(msg, outExp, outLoaded);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = net.output(newIn);
|
||||||
|
INDArray outMb = net2.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -306,6 +312,10 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(f, f);
|
||||||
|
net.output(newIn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -137,6 +137,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
INDArray outLoaded = netLoaded.output(in);
|
INDArray outLoaded = netLoaded.output(in);
|
||||||
|
|
||||||
assertEquals(outExp, outLoaded);
|
assertEquals(outExp, outLoaded);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = net.output(newIn);
|
||||||
|
INDArray outMb = net2.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -314,6 +320,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
netSD.computeGradientAndScore();
|
netSD.computeGradientAndScore();
|
||||||
// netStandard.computeGradientAndScore();
|
// netStandard.computeGradientAndScore();
|
||||||
// assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient());
|
// assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient());
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = netSD.output(newIn);
|
||||||
|
INDArray outMb = netStandard.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -377,6 +389,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
assertEquals(s, netStandard.params(), netSD.params());
|
assertEquals(s, netStandard.params(), netSD.params());
|
||||||
assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
|
assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures());
|
||||||
|
INDArray outMbsd = netSD.output(newIn);
|
||||||
|
INDArray outMb = netStandard.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -417,6 +435,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
|
|
||||||
TestUtils.testModelSerialization(net);
|
TestUtils.testModelSerialization(net);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(f, f);
|
||||||
|
net.output(newIn);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -166,6 +166,12 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
|
||||||
outSD = loaded.outputSingle(in);
|
outSD = loaded.outputSingle(in);
|
||||||
outStd = netStandard.outputSingle(in);
|
outStd = netStandard.outputSingle(in);
|
||||||
assertEquals(outStd, outSD);
|
assertEquals(outStd, outSD);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = netSD.output(newIn)[0];
|
||||||
|
INDArray outMb = netStandard.output(newIn)[0];
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,6 +115,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
|
||||||
outStd = std.outputSingle(in);
|
outStd = std.outputSingle(in);
|
||||||
|
|
||||||
assertEquals(outStd, outLambda);
|
assertEquals(outStd, outLambda);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = lambda.output(newIn)[0];
|
||||||
|
INDArray outMb = std.output(newIn)[0];
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -186,5 +192,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
|
||||||
outStd = std.output(in1, in2)[0];
|
outStd = std.output(in1, in2)[0];
|
||||||
|
|
||||||
assertEquals(outStd, outLambda);
|
assertEquals(outStd, outLambda);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn1 = Nd4j.vstack(in1, in1);
|
||||||
|
INDArray newIn2 = Nd4j.vstack(in2, in2);
|
||||||
|
INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
|
||||||
|
INDArray outMb = std.output(newIn1, newIn2)[0];
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,6 +90,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
||||||
net.init();
|
net.init();
|
||||||
net.fit(ds);
|
net.fit(ds);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = netSD.output(newIn);
|
||||||
|
INDArray outMb = netStd.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -164,6 +170,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
||||||
net.init();
|
net.init();
|
||||||
net.fit(ds);
|
net.fit(ds);
|
||||||
|
|
||||||
|
//Sanity check on different minibatch sizes:
|
||||||
|
INDArray newIn = Nd4j.vstack(in, in);
|
||||||
|
INDArray outMbsd = netSD.output(newIn);
|
||||||
|
INDArray outMb = netStd.output(newIn);
|
||||||
|
assertEquals(outMb, outMbsd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
|
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
|
||||||
|
import org.deeplearning4j.models.fasttext.FastText;
|
||||||
import org.deeplearning4j.models.glove.Glove;
|
import org.deeplearning4j.models.glove.Glove;
|
||||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||||
|
@ -3090,6 +3091,42 @@ public class WordVectorSerializer {
|
||||||
return word2Vec;
|
return word2Vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
|
||||||
|
ObjectOutputStream outputStream = null;
|
||||||
|
try {
|
||||||
|
outputStream = new ObjectOutputStream(new FileOutputStream(path ));
|
||||||
|
outputStream.writeObject(vectors);
|
||||||
|
}
|
||||||
|
finally {
|
||||||
|
try {
|
||||||
|
if (outputStream != null) {
|
||||||
|
outputStream.flush();
|
||||||
|
outputStream.close();
|
||||||
|
}
|
||||||
|
} catch (IOException ex) {
|
||||||
|
ex.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static FastText readWordVectors(File path) {
|
||||||
|
FastText result = null;
|
||||||
|
try {
|
||||||
|
FileInputStream fileIn = new FileInputStream(path);
|
||||||
|
ObjectInputStream in = new ObjectInputStream(fileIn);
|
||||||
|
try {
|
||||||
|
result = (FastText) in.readObject();
|
||||||
|
} catch (ClassNotFoundException ex) {
|
||||||
|
|
||||||
|
}
|
||||||
|
} catch (FileNotFoundException ex) {
|
||||||
|
ex.printStackTrace();
|
||||||
|
} catch (IOException ex) {
|
||||||
|
ex.printStackTrace();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
|
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
|
||||||
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
|
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import lombok.val;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
||||||
|
@ -207,7 +208,6 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray mean = words.isMatrix() ? words.mean(0) : words;
|
INDArray mean = words.isMatrix() ? words.mean(0) : words;
|
||||||
|
|
||||||
Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size());
|
Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size());
|
||||||
List<String> realResults = new ArrayList<>();
|
List<String> realResults = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -232,6 +232,22 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
||||||
return wordsNearestSum(vec, n);
|
return wordsNearestSum(vec, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected INDArray adjustRank(INDArray words) {
|
||||||
|
if (lookupTable instanceof InMemoryLookupTable) {
|
||||||
|
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||||
|
|
||||||
|
INDArray syn0 = l.getSyn0();
|
||||||
|
if (!words.dataType().equals(syn0.dataType())) {
|
||||||
|
return words.castTo(syn0.dataType());
|
||||||
|
}
|
||||||
|
if (words.rank() == 0 || words.rank() > 2) {
|
||||||
|
throw new IllegalStateException("Invalid rank for wordsNearest method");
|
||||||
|
} else if (words.rank() == 1) {
|
||||||
|
return words.reshape(1, -1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return words;
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* Words nearest based on positive and negative words
|
* Words nearest based on positive and negative words
|
||||||
* * @param top the top n words
|
* * @param top the top n words
|
||||||
|
@ -239,6 +255,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||||
|
words = adjustRank(words);
|
||||||
|
|
||||||
if (lookupTable instanceof InMemoryLookupTable) {
|
if (lookupTable instanceof InMemoryLookupTable) {
|
||||||
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.models.embeddings.reader.impl;
|
package org.deeplearning4j.models.embeddings.reader.impl;
|
||||||
|
|
||||||
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||||
|
@ -64,6 +65,8 @@ public class FlatModelUtils<T extends SequenceElement> extends BasicModelUtils<T
|
||||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||||
Counter<String> distances = new Counter<>();
|
Counter<String> distances = new Counter<>();
|
||||||
|
|
||||||
|
words = adjustRank(words);
|
||||||
|
|
||||||
for (String s : vocabCache.words()) {
|
for (String s : vocabCache.words()) {
|
||||||
INDArray otherVec = lookupTable.vector(s);
|
INDArray otherVec = lookupTable.vector(s);
|
||||||
double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup()));
|
double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup()));
|
||||||
|
|
|
@ -103,6 +103,7 @@ public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||||
checkTree();
|
checkTree();
|
||||||
|
words = adjustRank(words);
|
||||||
|
|
||||||
List<DataPoint> add = new ArrayList<>();
|
List<DataPoint> add = new ArrayList<>();
|
||||||
List<Double> distances = new ArrayList<>();
|
List<Double> distances = new ArrayList<>();
|
||||||
|
|
|
@ -172,4 +172,10 @@ public interface WordVectors extends Serializable, EmbeddingInitializer {
|
||||||
*/
|
*/
|
||||||
void setModelUtils(ModelUtils utils);
|
void setModelUtils(ModelUtils utils);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Does implementation vectorize words absent in vocabulary
|
||||||
|
* @return boolean
|
||||||
|
*/
|
||||||
|
boolean outOfVocabularySupported();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import com.google.common.util.concurrent.AtomicDouble;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
import lombok.val;
|
||||||
import org.apache.commons.lang.ArrayUtils;
|
import org.apache.commons.lang.ArrayUtils;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
|
@ -357,4 +358,9 @@ public class WordVectorsImpl<T extends SequenceElement> implements WordVectors {
|
||||||
public boolean jsonSerializable() {
|
public boolean jsonSerializable() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean outOfVocabularySupported() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,10 +6,13 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.NotImplementedException;
|
import org.apache.commons.lang3.NotImplementedException;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||||
|
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||||
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
|
@ -17,41 +20,78 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
import java.io.BufferedWriter;
|
import java.io.*;
|
||||||
import java.io.File;
|
import java.util.*;
|
||||||
import java.io.FileWriter;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@lombok.Builder
|
@lombok.Builder
|
||||||
public class FastText implements WordVectors {
|
public class FastText implements WordVectors, Serializable {
|
||||||
|
|
||||||
private boolean supervised;
|
// Mandatory
|
||||||
private boolean quantize;
|
@Getter private String inputFile;
|
||||||
private boolean predict;
|
@Getter private String outputFile;
|
||||||
private boolean predict_prob;
|
|
||||||
|
|
||||||
private boolean skipgram;
|
// Optional for dictionary
|
||||||
@Builder.Default private int bucket = 100;
|
@Builder.Default private int bucket = -1;
|
||||||
@Builder.Default private int minCount = 1;
|
@Builder.Default private int minCount = -1;
|
||||||
|
@Builder.Default private int minCountLabel = -1;
|
||||||
|
@Builder.Default private int wordNgrams = -1;
|
||||||
|
@Builder.Default private int minNgramLength = -1;
|
||||||
|
@Builder.Default private int maxNgramLength = -1;
|
||||||
|
@Builder.Default private int samplingThreshold = -1;
|
||||||
|
private String labelPrefix;
|
||||||
|
|
||||||
private boolean cbow;
|
// Optional for training
|
||||||
private boolean nn;
|
@Getter private boolean supervised;
|
||||||
private boolean analogies;
|
@Getter private boolean quantize;
|
||||||
private String inputFile;
|
@Getter private boolean predict;
|
||||||
private String outputFile;
|
@Getter private boolean predict_prob;
|
||||||
private SentenceIterator iterator;
|
@Getter private boolean skipgram;
|
||||||
private String modelName;
|
@Getter private boolean cbow;
|
||||||
private String lossName;
|
@Getter private boolean nn;
|
||||||
//TODO:
|
@Getter private boolean analogies;
|
||||||
private double[] pretrainedVectors;
|
@Getter private String pretrainedVectorsFile;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private double learningRate = -1.0;
|
||||||
|
@Getter private double learningRateUpdate = -1.0;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int dim = -1;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int contextWindowSize = -1;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int epochs = -1;
|
||||||
|
@Getter private String modelName;
|
||||||
|
@Getter private String lossName;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int negativeSamples = -1;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int numThreads = -1;
|
||||||
|
@Getter private boolean saveOutput = false;
|
||||||
|
|
||||||
private JFastText fastTextImpl;
|
// Optional for quantization
|
||||||
private boolean modelLoaded;
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int cutOff = -1;
|
||||||
|
@Getter private boolean retrain;
|
||||||
|
@Getter private boolean qnorm;
|
||||||
|
@Getter private boolean qout;
|
||||||
|
@Getter
|
||||||
|
@Builder.Default
|
||||||
|
private int dsub = -1;
|
||||||
|
|
||||||
|
@Getter private SentenceIterator iterator;
|
||||||
|
|
||||||
|
@Builder.Default private transient JFastText fastTextImpl = new JFastText();
|
||||||
|
private transient Word2Vec word2Vec;
|
||||||
|
@Getter private boolean modelLoaded;
|
||||||
|
@Getter private boolean modelVectorsLoaded;
|
||||||
private VocabCache vocabCache;
|
private VocabCache vocabCache;
|
||||||
|
|
||||||
public FastText(File modelPath) {
|
public FastText(File modelPath) {
|
||||||
|
@ -63,8 +103,97 @@ public class FastText implements WordVectors {
|
||||||
fastTextImpl = new JFastText();
|
fastTextImpl = new JFastText();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void init() {
|
private static class ArgsFactory {
|
||||||
fastTextImpl = new JFastText();
|
|
||||||
|
private List<String> args = new ArrayList<>();
|
||||||
|
|
||||||
|
private void add(String label, String value) {
|
||||||
|
args.add(label);
|
||||||
|
args.add(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addOptional(String label, int value) {
|
||||||
|
if (value >= 0) {
|
||||||
|
args.add(label);
|
||||||
|
args.add(Integer.toString(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addOptional(String label, double value) {
|
||||||
|
if (value >= 0.0) {
|
||||||
|
args.add(label);
|
||||||
|
args.add(Double.toString(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addOptional(String label, String value) {
|
||||||
|
if (StringUtils.isNotEmpty(value)) {
|
||||||
|
args.add(label);
|
||||||
|
args.add(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addOptional(String label, boolean value) {
|
||||||
|
if (value) {
|
||||||
|
args.add(label);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public String[] args() {
|
||||||
|
String[] asArray = new String[args.size()];
|
||||||
|
return args.toArray(asArray);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String[] makeArgs() {
|
||||||
|
ArgsFactory argsFactory = new ArgsFactory();
|
||||||
|
|
||||||
|
argsFactory.addOptional("cbow", cbow);
|
||||||
|
argsFactory.addOptional("skipgram", skipgram);
|
||||||
|
argsFactory.addOptional("supervised", supervised);
|
||||||
|
argsFactory.addOptional("quantize", quantize);
|
||||||
|
argsFactory.addOptional("predict", predict);
|
||||||
|
argsFactory.addOptional("predict_prob", predict_prob);
|
||||||
|
|
||||||
|
argsFactory.add("-input", inputFile);
|
||||||
|
argsFactory.add("-output", outputFile );
|
||||||
|
|
||||||
|
argsFactory.addOptional("-pretrainedVectors", pretrainedVectorsFile);
|
||||||
|
|
||||||
|
argsFactory.addOptional("-bucket", bucket);
|
||||||
|
argsFactory.addOptional("-minCount", minCount);
|
||||||
|
argsFactory.addOptional("-minCountLabel", minCountLabel);
|
||||||
|
argsFactory.addOptional("-wordNgrams", wordNgrams);
|
||||||
|
argsFactory.addOptional("-minn", minNgramLength);
|
||||||
|
argsFactory.addOptional("-maxn", maxNgramLength);
|
||||||
|
argsFactory.addOptional("-t", samplingThreshold);
|
||||||
|
argsFactory.addOptional("-label", labelPrefix);
|
||||||
|
argsFactory.addOptional("analogies",analogies);
|
||||||
|
argsFactory.addOptional("-lr", learningRate);
|
||||||
|
argsFactory.addOptional("-lrUpdateRate", learningRateUpdate);
|
||||||
|
argsFactory.addOptional("-dim", dim);
|
||||||
|
argsFactory.addOptional("-ws", contextWindowSize);
|
||||||
|
argsFactory.addOptional("-epoch", epochs);
|
||||||
|
argsFactory.addOptional("-loss", lossName);
|
||||||
|
argsFactory.addOptional("-neg", negativeSamples);
|
||||||
|
argsFactory.addOptional("-thread", numThreads);
|
||||||
|
argsFactory.addOptional("-saveOutput", saveOutput);
|
||||||
|
argsFactory.addOptional("-cutoff", cutOff);
|
||||||
|
argsFactory.addOptional("-retrain", retrain);
|
||||||
|
argsFactory.addOptional("-qnorm", qnorm);
|
||||||
|
argsFactory.addOptional("-qout", qout);
|
||||||
|
argsFactory.addOptional("-dsub", dsub);
|
||||||
|
|
||||||
|
return argsFactory.args();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void fit() {
|
||||||
|
String[] cmd = makeArgs();
|
||||||
|
fastTextImpl.runCmd(cmd);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void loadIterator() {
|
||||||
if (iterator != null) {
|
if (iterator != null) {
|
||||||
try {
|
try {
|
||||||
File tempFile = File.createTempFile("FTX", ".txt");
|
File tempFile = File.createTempFile("FTX", ".txt");
|
||||||
|
@ -81,24 +210,11 @@ public class FastText implements WordVectors {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void fit() {
|
public void loadPretrainedVectors(File vectorsFile) {
|
||||||
|
word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
|
||||||
String[] cmd;
|
modelVectorsLoaded = true;
|
||||||
if (skipgram) {
|
log.info("Loaded vectorized representation from file %s. Functionality will be restricted.",
|
||||||
cmd = new String[]{"skipgram", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
|
vectorsFile.getAbsolutePath());
|
||||||
"-input", inputFile, "-output", outputFile};
|
|
||||||
}
|
|
||||||
else if (cbow) {
|
|
||||||
cmd = new String[]{"cbow", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
|
|
||||||
"-input", inputFile, "-output", outputFile};
|
|
||||||
}
|
|
||||||
else if (supervised)
|
|
||||||
cmd = new String[]{"supervised", "-input", inputFile,
|
|
||||||
"-output", outputFile};
|
|
||||||
else
|
|
||||||
cmd = new String[]{"-input", inputFile,
|
|
||||||
"-output", outputFile};
|
|
||||||
fastTextImpl.runCmd(cmd);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void loadBinaryModel(String modelPath) {
|
public void loadBinaryModel(String modelPath) {
|
||||||
|
@ -111,10 +227,18 @@ public class FastText implements WordVectors {
|
||||||
modelLoaded = false;
|
modelLoaded = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void test(File testFile) {
|
||||||
|
fastTextImpl.test(testFile.getAbsolutePath());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void assertModelLoaded() {
|
||||||
|
if (!modelLoaded && !modelVectorsLoaded)
|
||||||
|
throw new IllegalStateException("Model must be loaded before predict!");
|
||||||
|
}
|
||||||
|
|
||||||
public String predict(String text) {
|
public String predict(String text) {
|
||||||
|
|
||||||
if (!modelLoaded)
|
assertModelLoaded();
|
||||||
throw new IllegalStateException("Model must be loaded before predict!");
|
|
||||||
|
|
||||||
String label = fastTextImpl.predict(text);
|
String label = fastTextImpl.predict(text);
|
||||||
return label;
|
return label;
|
||||||
|
@ -122,8 +246,7 @@ public class FastText implements WordVectors {
|
||||||
|
|
||||||
public Pair<String, Float> predictProbability(String text) {
|
public Pair<String, Float> predictProbability(String text) {
|
||||||
|
|
||||||
if (!modelLoaded)
|
assertModelLoaded();
|
||||||
throw new IllegalStateException("Model must be loaded before predict!");
|
|
||||||
|
|
||||||
JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);
|
JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);
|
||||||
|
|
||||||
|
@ -135,27 +258,39 @@ public class FastText implements WordVectors {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public VocabCache vocab() {
|
public VocabCache vocab() {
|
||||||
if (!modelLoaded)
|
if (modelVectorsLoaded) {
|
||||||
throw new IllegalStateException("Load model before calling vocab()");
|
vocabCache = word2Vec.vocab();
|
||||||
|
|
||||||
if (vocabCache == null) {
|
|
||||||
vocabCache = new AbstractCache();
|
|
||||||
}
|
}
|
||||||
List<String> words = fastTextImpl.getWords();
|
else {
|
||||||
for (int i = 0; i < words.size(); ++i) {
|
if (!modelLoaded)
|
||||||
vocabCache.addWordToIndex(i, words.get(i));
|
throw new IllegalStateException("Load model before calling vocab()");
|
||||||
VocabWord word = new VocabWord();
|
|
||||||
word.setWord(words.get(i));
|
if (vocabCache == null) {
|
||||||
vocabCache.addToken(word);
|
vocabCache = new AbstractCache();
|
||||||
|
}
|
||||||
|
List<String> words = fastTextImpl.getWords();
|
||||||
|
for (int i = 0; i < words.size(); ++i) {
|
||||||
|
vocabCache.addWordToIndex(i, words.get(i));
|
||||||
|
VocabWord word = new VocabWord();
|
||||||
|
word.setWord(words.get(i));
|
||||||
|
vocabCache.addToken(word);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return vocabCache;
|
return vocabCache;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long vocabSize() {
|
public long vocabSize() {
|
||||||
if (!modelLoaded)
|
long result = 0;
|
||||||
throw new IllegalStateException("Load model before calling vocab()");
|
if (modelVectorsLoaded) {
|
||||||
return fastTextImpl.getNWords();
|
result = word2Vec.vocabSize();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (!modelLoaded)
|
||||||
|
throw new IllegalStateException("Load model before calling vocab()");
|
||||||
|
result = fastTextImpl.getNWords();
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -170,99 +305,160 @@ public class FastText implements WordVectors {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double[] getWordVector(String word) {
|
public double[] getWordVector(String word) {
|
||||||
List<Float> vectors = fastTextImpl.getVector(word);
|
if (modelVectorsLoaded) {
|
||||||
double[] retVal = new double[vectors.size()];
|
return word2Vec.getWordVector(word);
|
||||||
for (int i = 0; i < vectors.size(); ++i) {
|
}
|
||||||
retVal[i] = vectors.get(i);
|
else {
|
||||||
|
List<Float> vectors = fastTextImpl.getVector(word);
|
||||||
|
double[] retVal = new double[vectors.size()];
|
||||||
|
for (int i = 0; i < vectors.size(); ++i) {
|
||||||
|
retVal[i] = vectors.get(i);
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
}
|
}
|
||||||
return retVal;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getWordVectorMatrixNormalized(String word) {
|
public INDArray getWordVectorMatrixNormalized(String word) {
|
||||||
INDArray r = getWordVectorMatrix(word);
|
if (modelVectorsLoaded) {
|
||||||
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
|
return word2Vec.getWordVectorMatrixNormalized(word);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
INDArray r = getWordVectorMatrix(word);
|
||||||
|
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getWordVectorMatrix(String word) {
|
public INDArray getWordVectorMatrix(String word) {
|
||||||
double[] values = getWordVector(word);
|
if (modelVectorsLoaded) {
|
||||||
return Nd4j.createFromArray(values);
|
return word2Vec.getWordVectorMatrix(word);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
double[] values = getWordVector(word);
|
||||||
|
return Nd4j.createFromArray(values);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getWordVectors(Collection<String> labels) {
|
public INDArray getWordVectors(Collection<String> labels) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.getWordVectors(labels);
|
||||||
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray getWordVectorsMean(Collection<String> labels) {
|
public INDArray getWordVectorsMean(Collection<String> labels) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.getWordVectorsMean(labels);
|
||||||
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<String> words = new ArrayList<>();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasWord(String word) {
|
public boolean hasWord(String word) {
|
||||||
return fastTextImpl.getWords().contains(word);
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.outOfVocabularySupported();
|
||||||
|
}
|
||||||
|
if (words.isEmpty())
|
||||||
|
words = fastTextImpl.getWords();
|
||||||
|
return words.contains(word);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected transient ModelUtils modelUtils = new BasicModelUtils<>();
|
protected transient ModelUtils modelUtils;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearest(words, top);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearest(words, top);
|
return modelUtils.wordsNearest(words, top);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearestSum(INDArray words, int top) {
|
public Collection<String> wordsNearestSum(INDArray words, int top) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearestSum(words, top);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearestSum(words, top);
|
return modelUtils.wordsNearestSum(words, top);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearestSum(String word, int n) {
|
public Collection<String> wordsNearestSum(String word, int n) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearestSum(word, n);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearestSum(word, n);
|
return modelUtils.wordsNearestSum(word, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
|
public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearestSum(positive, negative, top);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearestSum(positive, negative, top);
|
return modelUtils.wordsNearestSum(positive, negative, top);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<String, Double> accuracy(List<String> questions) {
|
public Map<String, Double> accuracy(List<String> questions) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.accuracy(questions);
|
||||||
|
}
|
||||||
return modelUtils.accuracy(questions);
|
return modelUtils.accuracy(questions);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int indexOf(String word) {
|
public int indexOf(String word) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.indexOf(word);
|
||||||
|
}
|
||||||
return vocab().indexOf(word);
|
return vocab().indexOf(word);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> similarWordsInVocabTo(String word, double accuracy) {
|
public List<String> similarWordsInVocabTo(String word, double accuracy) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.similarWordsInVocabTo(word, accuracy);
|
||||||
|
}
|
||||||
return modelUtils.similarWordsInVocabTo(word, accuracy);
|
return modelUtils.similarWordsInVocabTo(word, accuracy);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
|
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearest(positive, negative, top);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearest(positive, negative, top);
|
return modelUtils.wordsNearest(positive, negative, top);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Collection<String> wordsNearest(String word, int n) {
|
public Collection<String> wordsNearest(String word, int n) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.wordsNearest(word,n);
|
||||||
|
}
|
||||||
return modelUtils.wordsNearestSum(word, n);
|
return modelUtils.wordsNearestSum(word, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double similarity(String word, String word2) {
|
public double similarity(String word, String word2) {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.similarity(word, word2);
|
||||||
|
}
|
||||||
return modelUtils.similarity(word, word2);
|
return modelUtils.similarity(word, word2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public WeightLookupTable lookupTable() {
|
public WeightLookupTable lookupTable() {
|
||||||
|
if (modelVectorsLoaded) {
|
||||||
|
return word2Vec.lookupTable();
|
||||||
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,4 +516,9 @@ public class FastText implements WordVectors {
|
||||||
return fastTextImpl.getLabelPrefix();
|
return fastTextImpl.getLabelPrefix();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean outOfVocabularySupported() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -376,6 +376,11 @@ public class StaticWord2Vec implements WordVectors {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean outOfVocabularySupported() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
public static class Builder {
|
public static class Builder {
|
||||||
|
|
||||||
private AbstractStorage<Integer> storage;
|
private AbstractStorage<Integer> storage;
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package org.deeplearning4j.models.fasttext;
|
package org.deeplearning4j.models.fasttext;
|
||||||
|
|
||||||
|
import com.github.jfasttext.JFastText;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
|
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||||
|
@ -13,6 +17,7 @@ import org.nd4j.resources.Resources;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertArrayEquals;
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
@ -23,7 +28,9 @@ import static org.junit.Assert.assertEquals;
|
||||||
public class FastTextTest extends BaseDL4JTest {
|
public class FastTextTest extends BaseDL4JTest {
|
||||||
|
|
||||||
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
|
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
|
||||||
private File modelFile = Resources.asFile("models/fasttext/supervised.model.bin");
|
private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
|
||||||
|
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
|
||||||
|
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
|
||||||
|
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
|
@ -39,7 +46,6 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
outputFile(output.getAbsolutePath()).build();
|
outputFile(output.getAbsolutePath()).build();
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.init();
|
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,7 +59,6 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
outputFile(output.getAbsolutePath()).build();
|
outputFile(output.getAbsolutePath()).build();
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.init();
|
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,7 +73,6 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
outputFile(output.getAbsolutePath()).build();
|
outputFile(output.getAbsolutePath()).build();
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.init();
|
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,34 +86,42 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
inputFile(inputFile.getAbsolutePath()).
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
outputFile(output.getAbsolutePath()).build();
|
outputFile(output.getAbsolutePath()).build();
|
||||||
log.info("\nTraining supervised model ...\n");
|
log.info("\nTraining supervised model ...\n");
|
||||||
fastText.init();
|
|
||||||
fastText.fit();
|
fastText.fit();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
@Test
|
||||||
|
public void tesLoadCBOWModel() throws IOException {
|
||||||
|
|
||||||
|
FastText fastText = new FastText(cbowModelFile);
|
||||||
|
fastText.test(cbowModelFile);
|
||||||
|
|
||||||
|
assertEquals(19, fastText.vocab().numWords());
|
||||||
|
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||||
|
|
||||||
|
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
|
||||||
|
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPredict() throws IOException {
|
public void testPredict() throws IOException {
|
||||||
for (int i = 0; i < 100; ++i) {
|
|
||||||
String text = "I like soccer";
|
String text = "I like soccer";
|
||||||
|
|
||||||
FastText fastText = new FastText(modelFile);
|
FastText fastText = new FastText(supModelFile);
|
||||||
assertEquals(48, fastText.vocab().numWords());
|
assertEquals(48, fastText.vocab().numWords());
|
||||||
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||||
|
|
||||||
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
||||||
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-5);
|
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
||||||
|
|
||||||
String label = fastText.predict(text);
|
String label = fastText.predict(text);
|
||||||
assertEquals("__label__soccer", label);
|
assertEquals("__label__soccer", label);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
|
||||||
@Test
|
@Test
|
||||||
public void testPredictProbability() throws IOException {
|
public void testPredictProbability() throws IOException {
|
||||||
String text = "I like soccer";
|
String text = "I like soccer";
|
||||||
|
|
||||||
FastText fastText = new FastText(modelFile);
|
FastText fastText = new FastText(supModelFile);
|
||||||
|
|
||||||
Pair<String,Float> result = fastText.predictProbability(text);
|
Pair<String,Float> result = fastText.predictProbability(text);
|
||||||
assertEquals("__label__soccer", result.getFirst());
|
assertEquals("__label__soccer", result.getFirst());
|
||||||
|
@ -129,7 +141,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVocabulary() throws IOException {
|
public void testVocabulary() throws IOException {
|
||||||
FastText fastText = new FastText(modelFile);
|
FastText fastText = new FastText(supModelFile);
|
||||||
assertEquals(48, fastText.vocab().numWords());
|
assertEquals(48, fastText.vocab().numWords());
|
||||||
assertEquals(48, fastText.vocabSize());
|
assertEquals(48, fastText.vocabSize());
|
||||||
|
|
||||||
|
@ -149,7 +161,7 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||||
FastText fastText =
|
FastText fastText =
|
||||||
FastText.builder().supervised(true).iterator(iter).build();
|
FastText.builder().supervised(true).iterator(iter).build();
|
||||||
fastText.init();
|
fastText.loadIterator();
|
||||||
|
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
log.error(e.toString());
|
log.error(e.toString());
|
||||||
|
@ -162,4 +174,60 @@ public class FastTextTest extends BaseDL4JTest {
|
||||||
String label = fastText.predict("something");
|
String label = fastText.predict("something");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPretrainedVectors() throws IOException {
|
||||||
|
File output = testDir.newFile();
|
||||||
|
|
||||||
|
FastText fastText =
|
||||||
|
FastText.builder().supervised(true).
|
||||||
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
|
pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
|
||||||
|
outputFile(output.getAbsolutePath()).build();
|
||||||
|
log.info("\nTraining supervised model ...\n");
|
||||||
|
fastText.fit();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWordsStatistics() throws IOException {
|
||||||
|
|
||||||
|
File output = testDir.newFile();
|
||||||
|
|
||||||
|
FastText fastText =
|
||||||
|
FastText.builder().supervised(true).
|
||||||
|
inputFile(inputFile.getAbsolutePath()).
|
||||||
|
outputFile(output.getAbsolutePath()).build();
|
||||||
|
|
||||||
|
log.info("\nTraining supervised model ...\n");
|
||||||
|
fastText.fit();
|
||||||
|
|
||||||
|
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
|
||||||
|
|
||||||
|
assertEquals(48, word2Vec.getVocab().numWords());
|
||||||
|
|
||||||
|
System.out.println(word2Vec.wordsNearest("association", 3));
|
||||||
|
System.out.println(word2Vec.similarity("Football", "teams"));
|
||||||
|
System.out.println(word2Vec.similarity("professional", "minutes"));
|
||||||
|
System.out.println(word2Vec.similarity("java","cpp"));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testWordsNativeStatistics() throws IOException {
|
||||||
|
|
||||||
|
File output = testDir.newFile();
|
||||||
|
|
||||||
|
FastText fastText = new FastText();
|
||||||
|
fastText.loadPretrainedVectors(supervisedVectors);
|
||||||
|
|
||||||
|
log.info("\nTraining supervised model ...\n");
|
||||||
|
|
||||||
|
assertEquals(48, fastText.vocab().numWords());
|
||||||
|
|
||||||
|
String[] result = new String[3];
|
||||||
|
fastText.wordsNearest("association", 3).toArray(result);
|
||||||
|
assertArrayEquals(new String[]{"most","eleven","hours"}, result);
|
||||||
|
assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
|
||||||
|
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
|
||||||
|
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
||||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||||
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
||||||
|
import org.deeplearning4j.models.fasttext.FastText;
|
||||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
|
@ -42,6 +43,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.io.ByteArrayInputStream;
|
import java.io.ByteArrayInputStream;
|
||||||
import java.io.ByteArrayOutputStream;
|
import java.io.ByteArrayOutputStream;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
@ -289,4 +291,33 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void FastText_Correct_WhenDeserialized() throws IOException {
|
||||||
|
|
||||||
|
FastText fastText =
|
||||||
|
FastText.builder().cbow(true).build();
|
||||||
|
|
||||||
|
WordVectorSerializer.writeWordVectors(fastText, new File("some.data"));
|
||||||
|
|
||||||
|
FastText deser = null;
|
||||||
|
try {
|
||||||
|
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||||
|
deser = WordVectorSerializer.readWordVectors(new File("some.data"));
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
fail();
|
||||||
|
}
|
||||||
|
|
||||||
|
assertNotNull(deser);
|
||||||
|
assertEquals(fastText.isCbow(), deser.isCbow());
|
||||||
|
assertEquals(fastText.isModelLoaded(), deser.isModelLoaded());
|
||||||
|
assertEquals(fastText.isAnalogies(), deser.isAnalogies());
|
||||||
|
assertEquals(fastText.isNn(), deser.isNn());
|
||||||
|
assertEquals(fastText.isPredict(), deser.isPredict());
|
||||||
|
assertEquals(fastText.isPredict_prob(), deser.isPredict_prob());
|
||||||
|
assertEquals(fastText.isQuantize(), deser.isQuantize());
|
||||||
|
assertEquals(fastText.getInputFile(), deser.getInputFile());
|
||||||
|
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -453,6 +453,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
DataType netDtype = getConfiguration().getDataType();
|
DataType netDtype = getConfiguration().getDataType();
|
||||||
if(parameters != null && parameters.dataType() != netDtype){
|
if(parameters != null && parameters.dataType() != netDtype){
|
||||||
|
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
|
||||||
if(cloneParametersArray){
|
if(cloneParametersArray){
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
parameters = parameters.castTo(netDtype);
|
parameters = parameters.castTo(netDtype);
|
||||||
|
|
|
@ -178,8 +178,7 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
|
||||||
", mask shape: " + Arrays.toString(maskArray.shape()));
|
", mask shape: " + Arrays.toString(maskArray.shape()));
|
||||||
}
|
}
|
||||||
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
|
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
|
||||||
Broadcast.mul(ret, maskArray, ret, 0, 2);
|
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
|
||||||
// ret.muliColumnVector(maskArray);
|
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
|
@ -616,6 +616,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
DataType netDtype = getLayerWiseConfigurations().getDataType();
|
DataType netDtype = getLayerWiseConfigurations().getDataType();
|
||||||
if(parameters != null && parameters.dataType() != netDtype){
|
if(parameters != null && parameters.dataType() != netDtype){
|
||||||
|
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
|
||||||
if(cloneParametersArray){
|
if(cloneParametersArray){
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
parameters = parameters.castTo(netDtype);
|
parameters = parameters.castTo(netDtype);
|
||||||
|
@ -627,6 +628,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (layerMap == null)
|
if (layerMap == null)
|
||||||
layerMap = new LinkedHashMap<>();
|
layerMap = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ import java.util.Arrays;
|
||||||
@Ignore
|
@Ignore
|
||||||
public class TestSameDiffUI {
|
public class TestSameDiffUI {
|
||||||
|
|
||||||
// @Ignore
|
@Ignore
|
||||||
@Test
|
@Test
|
||||||
public void testSameDiff() throws Exception {
|
public void testSameDiff() throws Exception {
|
||||||
|
|
||||||
|
|
|
@ -1598,9 +1598,6 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
int NDArray::rankOf() const {
|
int NDArray::rankOf() const {
|
||||||
if (isEmpty())
|
|
||||||
return 0;
|
|
||||||
|
|
||||||
return shape::rank(_shapeInfo);
|
return shape::rank(_shapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ std::string NDArray::e(const Nd4jLong i) const;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
NDArray* NDArray::asT() const{
|
NDArray* NDArray::asT() const{
|
||||||
|
|
||||||
auto result = new NDArray(ordering(), isScalar() ? std::vector<Nd4jLong>({0}) : getShapeAsVector(), DataTypeUtils::fromT<T>());
|
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
||||||
auto l = this->lengthOf();
|
auto l = this->lengthOf();
|
||||||
|
|
||||||
prepareSpecialUse({result}, {this});
|
prepareSpecialUse({result}, {this});
|
||||||
|
@ -67,17 +67,18 @@ NDArray::NDArray(const NDArray& other) {
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||||
|
|
||||||
if (shape.empty())
|
|
||||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
|
||||||
|
|
||||||
if ((int) shape.size() > MAX_RANK)
|
if ((int) shape.size() > MAX_RANK)
|
||||||
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
||||||
|
|
||||||
_context = context;
|
_context = context;
|
||||||
_isAttached = getContext()->getWorkspace() != nullptr;
|
_isAttached = _context->getWorkspace() != nullptr;
|
||||||
_offset = 0;
|
_offset = 0;
|
||||||
|
|
||||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
if (shape.empty())
|
||||||
|
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
|
||||||
|
else
|
||||||
|
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||||
|
|
||||||
_buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace());
|
_buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace());
|
||||||
_buffer->setToZeroBuffers();
|
_buffer->setToZeroBuffers();
|
||||||
}
|
}
|
||||||
|
@ -85,16 +86,20 @@ NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::Dat
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||||
|
|
||||||
if (shape.empty())
|
|
||||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
|
||||||
|
|
||||||
if ((int) shape.size() > MAX_RANK)
|
if ((int) shape.size() > MAX_RANK)
|
||||||
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
||||||
|
|
||||||
_context = context;
|
_context = context;
|
||||||
_offset = 0;
|
_offset = 0;
|
||||||
|
|
||||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
if (shape.size() == 0) {
|
||||||
|
if (data.size() == 0)
|
||||||
|
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
|
||||||
|
else
|
||||||
|
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
|
||||||
|
} else {
|
||||||
|
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||||
|
}
|
||||||
|
|
||||||
if (lengthOf() != data.size()) {
|
if (lengthOf() != data.size()) {
|
||||||
nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf());
|
nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf());
|
||||||
|
@ -2441,6 +2446,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
|
||||||
if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB()))
|
if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB()))
|
||||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
|
throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
|
||||||
|
|
||||||
|
if (isEmpty() || other->isEmpty())
|
||||||
|
return;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({target}, {this, other});
|
NDArray::prepareSpecialUse({target}, {this, other});
|
||||||
|
|
||||||
if (isScalar()) {
|
if (isScalar()) {
|
||||||
|
@ -2513,6 +2521,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
||||||
if(target == nullptr || other == nullptr)
|
if(target == nullptr || other == nullptr)
|
||||||
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !");
|
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !");
|
||||||
|
|
||||||
|
if (isEmpty() || other->isEmpty())
|
||||||
|
return;
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({target}, {this, other});
|
NDArray::prepareSpecialUse({target}, {this, other});
|
||||||
|
|
||||||
if (isScalar()) {
|
if (isScalar()) {
|
||||||
|
@ -2583,6 +2594,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
|
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
|
||||||
|
if (isEmpty() || other.isEmpty()) {
|
||||||
|
if (isEmpty())
|
||||||
|
return NDArray(*this);
|
||||||
|
else
|
||||||
|
return NDArray(other);
|
||||||
|
}
|
||||||
|
|
||||||
Nd4jLong* newShapeInfo = nullptr;
|
Nd4jLong* newShapeInfo = nullptr;
|
||||||
if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)()
|
if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)()
|
||||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
||||||
|
@ -2812,6 +2830,19 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end();
|
||||||
|
|
||||||
|
if(isEmpty() && !isOutShapeEmpty)
|
||||||
|
throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !");
|
||||||
|
if(!isEmpty() && isOutShapeEmpty)
|
||||||
|
throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !");
|
||||||
|
if(isEmpty() && isOutShapeEmpty) {
|
||||||
|
Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace());
|
||||||
|
setShapeInfo(shapeInfoNew);
|
||||||
|
RELEASE(shapeInfoNew, getContext()->getWorkspace());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape(cshape);
|
std::vector<Nd4jLong> shape(cshape);
|
||||||
int rank = shape.size();
|
int rank = shape.size();
|
||||||
|
|
||||||
|
@ -2823,7 +2854,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||||
for (int i = 0; i < (int) shape.size(); i++) {
|
for (int i = 0; i < (int) shape.size(); i++) {
|
||||||
if (shape[i] < 0) {
|
if (shape[i] < 0) {
|
||||||
if (numberNegativesOnes >= 1)
|
if (numberNegativesOnes >= 1)
|
||||||
throw std::runtime_error("Only one dimension can be negative at once");
|
throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
|
||||||
|
|
||||||
numberNegativesOnes++;
|
numberNegativesOnes++;
|
||||||
|
|
||||||
|
@ -3664,7 +3695,7 @@ void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, co
|
||||||
if(rankOf() == copy.size() || copy.empty()) {
|
if(rankOf() == copy.size() || copy.empty()) {
|
||||||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo());
|
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo());
|
||||||
}
|
}
|
||||||
else {
|
else { //if (!isEmpty()) {
|
||||||
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
|
||||||
NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||||
|
@ -4198,6 +4229,9 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& d
|
||||||
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
|
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
|
||||||
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
|
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
|
||||||
|
|
||||||
|
if(isEmpty())
|
||||||
|
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");
|
||||||
|
|
||||||
const int rank = rankOf();
|
const int rank = rankOf();
|
||||||
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
|
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
|
||||||
|
|
||||||
|
@ -4260,6 +4294,9 @@ NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& di
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const {
|
void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const {
|
||||||
|
|
||||||
|
if(isEmpty())
|
||||||
|
throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !");
|
||||||
|
|
||||||
const int rank = rankOf();
|
const int rank = rankOf();
|
||||||
const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size();
|
const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size();
|
||||||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude);
|
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude);
|
||||||
|
|
|
@ -1334,18 +1334,7 @@ public:
|
||||||
* @param npyArray
|
* @param npyArray
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray);
|
||||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
|
||||||
unsigned int shapeSize = arr.shape.size();
|
|
||||||
auto shape = new unsigned int[shapeSize];
|
|
||||||
for(unsigned int i = 0; i < shapeSize; i++) {
|
|
||||||
shape[i] = arr.shape[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder);
|
|
||||||
delete[] shape;
|
|
||||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -64,8 +64,8 @@ void NDArray::tickWriteDevice() const { }
|
||||||
void NDArray::tickReadHost() const { }
|
void NDArray::tickReadHost() const { }
|
||||||
void NDArray::tickReadDevice() const { }
|
void NDArray::tickReadDevice() const { }
|
||||||
void NDArray::tickBothActual() const { }
|
void NDArray::tickBothActual() const { }
|
||||||
bool NDArray::isActualOnHostSide() const { }
|
bool NDArray::isActualOnHostSide() const { return true; }
|
||||||
bool NDArray::isActualOnDeviceSide() const { }
|
bool NDArray::isActualOnDeviceSide() const { return true; }
|
||||||
void NDArray::makeBothBuffersActual() const { }
|
void NDArray::makeBothBuffersActual() const { }
|
||||||
|
|
||||||
|
|
||||||
|
@ -419,328 +419,8 @@ void NDArray::repeat(int dimension, NDArray& target) const {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
|
|
||||||
template<typename T>
|
#include "NDArrayLambda.hpp"
|
||||||
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
|
|
||||||
if (target == nullptr)
|
|
||||||
target = this;
|
|
||||||
|
|
||||||
if (second == nullptr) {
|
|
||||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
|
|
||||||
throw std::runtime_error("second is null");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (third == nullptr) {
|
|
||||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
|
|
||||||
throw std::runtime_error("third is null");
|
|
||||||
}
|
|
||||||
if(dataType() != DataTypeUtils::fromT<T>())
|
|
||||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
|
||||||
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
|
|
||||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
|
||||||
|
|
||||||
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
|
||||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
|
||||||
throw std::runtime_error("Shapes mismach");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto f = this->bufferAsT<T>();
|
|
||||||
auto s = second->bufferAsT<T>();
|
|
||||||
auto t = third->bufferAsT<T>();
|
|
||||||
auto z = target->bufferAsT<T>();
|
|
||||||
|
|
||||||
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
|
||||||
z[e] = func(f[e], s[e], t[e]);
|
|
||||||
} else {
|
|
||||||
if (f == z) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto tOffset = this->getOffset(e);
|
|
||||||
auto uOffset = second->getOffset(e);
|
|
||||||
auto vOffset = third->getOffset(e);
|
|
||||||
|
|
||||||
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto tOffset = this->getOffset(e);
|
|
||||||
auto uOffset = second->getOffset(e);
|
|
||||||
auto vOffset = third->getOffset(e);
|
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
|
|
||||||
if (target == nullptr)
|
|
||||||
target = this;
|
|
||||||
|
|
||||||
if (other == nullptr) {
|
|
||||||
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
|
||||||
throw std::runtime_error("Other is null");
|
|
||||||
}
|
|
||||||
|
|
||||||
if(dataType() != DataTypeUtils::fromT<T>())
|
|
||||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
|
||||||
if(dataType() != other->dataType() || dataType() != target->dataType())
|
|
||||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
|
|
||||||
|
|
||||||
if (this->lengthOf() != other->lengthOf()) {
|
|
||||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
|
||||||
throw std::runtime_error("Shapes mismach");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto f = this->bufferAsT<T>();
|
|
||||||
auto s = other->bufferAsT<T>();
|
|
||||||
auto z = target->bufferAsT<T>();
|
|
||||||
|
|
||||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++)
|
|
||||||
z[e] = func(f[e], s[e]);
|
|
||||||
} else {
|
|
||||||
if (f == z) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto yOffset = other->getOffset(e);
|
|
||||||
|
|
||||||
f[xOffset] = func(f[xOffset], s[yOffset]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto yOffset = other->getOffset(e);
|
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func(f[xOffset], s[yOffset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
|
|
||||||
if (target == nullptr)
|
|
||||||
target = this;
|
|
||||||
|
|
||||||
if(dataType() != DataTypeUtils::fromT<T>())
|
|
||||||
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
|
||||||
if(dataType() != target->dataType())
|
|
||||||
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
|
|
||||||
|
|
||||||
auto f = this->bufferAsT<T>();
|
|
||||||
auto z = target->bufferAsT<T>();
|
|
||||||
|
|
||||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++)
|
|
||||||
z[e] = func(f[e]);
|
|
||||||
} else {
|
|
||||||
if (f == z) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
|
|
||||||
f[xOffset] = func(f[xOffset]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func(f[xOffset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
|
|
||||||
if (target == nullptr)
|
|
||||||
target = this;
|
|
||||||
|
|
||||||
if(dataType() != DataTypeUtils::fromT<T>())
|
|
||||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
|
||||||
if(dataType() != target->dataType())
|
|
||||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
|
|
||||||
|
|
||||||
auto f = this->bufferAsT<T>();
|
|
||||||
auto z = target->bufferAsT<T>();
|
|
||||||
|
|
||||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
|
||||||
z[e] = func(e, f[e]);
|
|
||||||
} else {
|
|
||||||
if (f == z) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
|
|
||||||
f[xOffset] = func(e, f[xOffset]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (Nd4jLong e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func(e, f[xOffset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
template<typename T>
|
|
||||||
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
|
|
||||||
if (target == nullptr)
|
|
||||||
target = this;
|
|
||||||
|
|
||||||
if (other == nullptr) {
|
|
||||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
|
||||||
throw std::runtime_error("Other is null");
|
|
||||||
}
|
|
||||||
if(dataType() != DataTypeUtils::fromT<T>())
|
|
||||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
|
||||||
if(dataType() != target->dataType())
|
|
||||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
|
|
||||||
if (this->lengthOf() != other->lengthOf()) {
|
|
||||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
|
|
||||||
throw std::runtime_error("Shapes mismach");
|
|
||||||
}
|
|
||||||
|
|
||||||
auto f = this->bufferAsT<T>();
|
|
||||||
auto s = other->bufferAsT<T>();
|
|
||||||
auto z = target->bufferAsT<T>();
|
|
||||||
|
|
||||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (Nd4jLong e = 0; e < _length; e++)
|
|
||||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
|
||||||
} else {
|
|
||||||
if (f == z) {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto yOffset = other->getOffset(e);
|
|
||||||
|
|
||||||
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
|
||||||
for (int e = 0; e < _length; e++) {
|
|
||||||
|
|
||||||
auto xOffset = this->getOffset(e);
|
|
||||||
auto yOffset = other->getOffset(e);
|
|
||||||
auto zOffset = target->getOffset(e);
|
|
||||||
|
|
||||||
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
|
|
||||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
|
@ -0,0 +1,325 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
|
||||||
|
if (target == nullptr)
|
||||||
|
target = this;
|
||||||
|
|
||||||
|
if (second == nullptr) {
|
||||||
|
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
|
||||||
|
throw std::runtime_error("second is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (third == nullptr) {
|
||||||
|
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
|
||||||
|
throw std::runtime_error("third is null");
|
||||||
|
}
|
||||||
|
if(dataType() != DataTypeUtils::fromT<T>())
|
||||||
|
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||||
|
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
|
||||||
|
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
||||||
|
|
||||||
|
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
||||||
|
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||||
|
throw std::runtime_error("Shapes mismach");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto f = this->bufferAsT<T>();
|
||||||
|
auto s = second->bufferAsT<T>();
|
||||||
|
auto t = third->bufferAsT<T>();
|
||||||
|
auto z = target->bufferAsT<T>();
|
||||||
|
|
||||||
|
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong e = 0; e < _length; e++)
|
||||||
|
z[e] = func(f[e], s[e], t[e]);
|
||||||
|
} else {
|
||||||
|
if (f == z) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto tOffset = this->getOffset(e);
|
||||||
|
auto uOffset = second->getOffset(e);
|
||||||
|
auto vOffset = third->getOffset(e);
|
||||||
|
|
||||||
|
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto tOffset = this->getOffset(e);
|
||||||
|
auto uOffset = second->getOffset(e);
|
||||||
|
auto vOffset = third->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
|
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
|
||||||
|
if (target == nullptr)
|
||||||
|
target = this;
|
||||||
|
|
||||||
|
if (other == nullptr) {
|
||||||
|
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||||
|
throw std::runtime_error("Other is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
if(dataType() != DataTypeUtils::fromT<T>())
|
||||||
|
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||||
|
if(dataType() != other->dataType() || dataType() != target->dataType())
|
||||||
|
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
|
||||||
|
|
||||||
|
if (this->lengthOf() != other->lengthOf()) {
|
||||||
|
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||||
|
throw std::runtime_error("Shapes mismach");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto f = this->bufferAsT<T>();
|
||||||
|
auto s = other->bufferAsT<T>();
|
||||||
|
auto z = target->bufferAsT<T>();
|
||||||
|
|
||||||
|
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++)
|
||||||
|
z[e] = func(f[e], s[e]);
|
||||||
|
} else {
|
||||||
|
if (f == z) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
|
||||||
|
f[xOffset] = func(f[xOffset], s[yOffset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
|
z[zOffset] = func(f[xOffset], s[yOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
|
||||||
|
if (target == nullptr)
|
||||||
|
target = this;
|
||||||
|
|
||||||
|
if(dataType() != DataTypeUtils::fromT<T>())
|
||||||
|
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||||
|
if(dataType() != target->dataType())
|
||||||
|
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
|
||||||
|
|
||||||
|
auto f = this->bufferAsT<T>();
|
||||||
|
auto z = target->bufferAsT<T>();
|
||||||
|
|
||||||
|
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++)
|
||||||
|
z[e] = func(f[e]);
|
||||||
|
} else {
|
||||||
|
if (f == z) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
|
||||||
|
f[xOffset] = func(f[xOffset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
|
z[zOffset] = func(f[xOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
|
||||||
|
if (target == nullptr)
|
||||||
|
target = this;
|
||||||
|
|
||||||
|
if(dataType() != DataTypeUtils::fromT<T>())
|
||||||
|
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||||
|
if(dataType() != target->dataType())
|
||||||
|
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
|
||||||
|
|
||||||
|
auto f = this->bufferAsT<T>();
|
||||||
|
auto z = target->bufferAsT<T>();
|
||||||
|
|
||||||
|
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong e = 0; e < _length; e++)
|
||||||
|
z[e] = func(e, f[e]);
|
||||||
|
} else {
|
||||||
|
if (f == z) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
|
||||||
|
f[xOffset] = func(e, f[xOffset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
|
z[zOffset] = func(e, f[xOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename T>
|
||||||
|
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
|
||||||
|
if (target == nullptr)
|
||||||
|
target = this;
|
||||||
|
|
||||||
|
if (other == nullptr) {
|
||||||
|
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||||
|
throw std::runtime_error("Other is null");
|
||||||
|
}
|
||||||
|
if(dataType() != DataTypeUtils::fromT<T>())
|
||||||
|
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||||
|
if(dataType() != target->dataType())
|
||||||
|
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
|
||||||
|
if (this->lengthOf() != other->lengthOf()) {
|
||||||
|
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
|
||||||
|
throw std::runtime_error("Shapes mismach");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto f = this->bufferAsT<T>();
|
||||||
|
auto s = other->bufferAsT<T>();
|
||||||
|
auto z = target->bufferAsT<T>();
|
||||||
|
|
||||||
|
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (Nd4jLong e = 0; e < _length; e++)
|
||||||
|
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||||
|
} else {
|
||||||
|
if (f == z) {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
|
||||||
|
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||||
|
for (int e = 0; e < _length; e++) {
|
||||||
|
|
||||||
|
auto xOffset = this->getOffset(e);
|
||||||
|
auto yOffset = other->getOffset(e);
|
||||||
|
auto zOffset = target->getOffset(e);
|
||||||
|
|
||||||
|
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
|
||||||
|
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
|
|
@ -2710,6 +2710,32 @@ int NativeOps::dataTypeFromNpyHeader(void *header) {
|
||||||
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
|
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||||
|
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||||
|
unsigned int shapeSize = arr.shape.size();
|
||||||
|
std::vector<Nd4jLong> shape(shapeSize);
|
||||||
|
bool _empty = false;
|
||||||
|
for(unsigned int i = 0; i < shapeSize; i++) {
|
||||||
|
shape[i] = arr.shape[i];
|
||||||
|
|
||||||
|
if (arr.shape[i] == 0)
|
||||||
|
_empty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||||
|
|
||||||
|
Nd4jLong *shapeBuffer;
|
||||||
|
if (_empty) {
|
||||||
|
if (shapeSize > 0)
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||||
|
else
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
|
||||||
|
} else {
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||||
|
}
|
||||||
|
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||||
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||||
|
|
|
@ -454,7 +454,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
|
||||||
|
|
||||||
if (ews() != 1) {
|
if (ews() != 1) {
|
||||||
for (uint i = 0; i < _length; i++)
|
for (uint i = 0; i < _length; i++)
|
||||||
cudaMemcpyAsync(pHost + i * sizeof(T), getSpecialBuffer() + getOffset(i) * sizeof(T), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
|
cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
|
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
|
||||||
|
@ -475,6 +475,12 @@ template void NDArray::printCurrentBuffer<float>(const bool host, const char* ms
|
||||||
template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const;
|
template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const;
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
|
||||||
|
|
||||||
|
#include <cpu/NDArrayLambda.hpp>
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
} // end namespace nd4j
|
} // end namespace nd4j
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
|
@ -3105,3 +3105,29 @@ nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, double
|
||||||
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) {
|
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) {
|
||||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||||
|
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||||
|
unsigned int shapeSize = arr.shape.size();
|
||||||
|
std::vector<Nd4jLong> shape(shapeSize);
|
||||||
|
bool _empty = false;
|
||||||
|
for(unsigned int i = 0; i < shapeSize; i++) {
|
||||||
|
shape[i] = arr.shape[i];
|
||||||
|
|
||||||
|
if (arr.shape[i] == 0)
|
||||||
|
_empty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||||
|
|
||||||
|
Nd4jLong *shapeBuffer;
|
||||||
|
if (_empty) {
|
||||||
|
if (shapeSize > 0)
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||||
|
else
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
|
||||||
|
} else {
|
||||||
|
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||||
|
}
|
||||||
|
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||||
|
}
|
||||||
|
|
|
@ -53,6 +53,15 @@ namespace nd4j {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE static _CUDA_HD T max();
|
FORCEINLINE static _CUDA_HD T max();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* returns inf for float/double and max for everything else
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static _CUDA_HD T infOrMax();
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE static _CUDA_HD T nanOrZero();
|
||||||
|
|
||||||
// returns the difference between 1.0 and the next representable value of the given floating-point type
|
// returns the difference between 1.0 and the next representable value of the given floating-point type
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE static T eps();
|
FORCEINLINE static T eps();
|
||||||
|
@ -290,6 +299,36 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max<bfloat16>() {
|
||||||
return bfloat16::max();
|
return bfloat16::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax<float>() {
|
||||||
|
return std::numeric_limits<float>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax<double>() {
|
||||||
|
return std::numeric_limits<double>::infinity();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() {
|
||||||
|
return DataTypeUtils::max<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero<float>() {
|
||||||
|
return std::numeric_limits<float>::quiet_NaN();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero<double>() {
|
||||||
|
return std::numeric_limits<double>::quiet_NaN();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() {
|
||||||
|
return static_cast<T>(0);
|
||||||
|
}
|
||||||
|
|
||||||
FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||||
switch(dataType) {
|
switch(dataType) {
|
||||||
case INT8:
|
case INT8:
|
||||||
|
|
|
@ -55,8 +55,13 @@ bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const {
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* ShapeDescriptor::toShapeInfo() const {
|
Nd4jLong* ShapeDescriptor::toShapeInfo() const {
|
||||||
if (_empty)
|
if (_empty) {
|
||||||
return ShapeBuilders::emptyShapeInfo(_dataType);
|
if (_rank == 0)
|
||||||
|
return ShapeBuilders::emptyShapeInfo(_dataType);
|
||||||
|
else {
|
||||||
|
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
switch (_rank) {
|
switch (_rank) {
|
||||||
|
@ -133,15 +138,11 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
|
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
|
||||||
_rank = ((shape.size() == 1 && shape[0] == 0)? 0: shape.size());
|
_rank = shape.size();
|
||||||
_ews = 1;
|
_ews = 1;
|
||||||
|
|
||||||
if (_rank > 0) {
|
if (_rank > 0) {
|
||||||
_strides.resize(_rank);
|
_strides.resize(_rank);
|
||||||
if (order == 'c')
|
|
||||||
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
|
|
||||||
else
|
|
||||||
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
|
|
||||||
|
|
||||||
for (auto v:_shape) {
|
for (auto v:_shape) {
|
||||||
if (v == 0) {
|
if (v == 0) {
|
||||||
|
@ -149,6 +150,17 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// no point calculating strides for empty arrays
|
||||||
|
if (!_empty) {
|
||||||
|
if (order == 'c')
|
||||||
|
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
|
||||||
|
else
|
||||||
|
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
|
||||||
|
} else {
|
||||||
|
// all strides set to 0
|
||||||
|
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,8 +203,11 @@ ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
|
||||||
|
|
||||||
_empty = shape::isEmpty(shapeInfo);
|
_empty = shape::isEmpty(shapeInfo);
|
||||||
|
|
||||||
for (int e = 0; e < _rank; e++)
|
for (int e = 0; e < _rank; e++) {
|
||||||
_shape.emplace_back(shapeInfo[e + 1]);
|
_shape.emplace_back(shapeInfo[e + 1]);
|
||||||
|
if (shapeInfo[e + 1] == 0)
|
||||||
|
_empty = true;
|
||||||
|
}
|
||||||
|
|
||||||
for (int e = 0; e < _rank; e++)
|
for (int e = 0; e < _rank; e++)
|
||||||
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
|
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
|
||||||
|
@ -304,7 +319,14 @@ ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const D
|
||||||
ShapeDescriptor descriptor;
|
ShapeDescriptor descriptor;
|
||||||
descriptor._dataType = type;
|
descriptor._dataType = type;
|
||||||
descriptor._shape.emplace_back(length);
|
descriptor._shape.emplace_back(length);
|
||||||
descriptor._strides.emplace_back(1);
|
|
||||||
|
if (length > 0)
|
||||||
|
descriptor._strides.emplace_back(1);
|
||||||
|
else {
|
||||||
|
descriptor._strides.emplace_back(0);
|
||||||
|
descriptor._empty = true;
|
||||||
|
}
|
||||||
|
|
||||||
descriptor._order = 'c';
|
descriptor._order = 'c';
|
||||||
descriptor._ews = 1;
|
descriptor._ews = 1;
|
||||||
descriptor._rank = 1;
|
descriptor._rank = 1;
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
#include <array/ArrayOptions.h>
|
#include <array/ArrayOptions.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
class ShapeBuilders {
|
class ND4J_EXPORT ShapeBuilders {
|
||||||
public:
|
public:
|
||||||
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
|
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
@ -53,6 +53,8 @@ namespace nd4j {
|
||||||
|
|
||||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,12 @@ namespace nd4j {
|
||||||
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
||||||
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate output shape for reduce operation when input shape is empty
|
||||||
|
* behavior is analogous to tf
|
||||||
|
*/
|
||||||
|
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
|
||||||
|
|
||||||
// evaluate shape for array which is result of repeat operation applied to arr
|
// evaluate shape for array which is result of repeat operation applied to arr
|
||||||
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);
|
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);
|
||||||
|
|
||||||
|
|
|
@ -71,7 +71,9 @@ namespace nd4j {
|
||||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||||
|
|
||||||
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
if (numOfSubArrs > 0)
|
||||||
|
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||||
|
|
||||||
|
|
||||||
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
|
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
|
||||||
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);
|
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);
|
||||||
|
|
|
@ -75,7 +75,8 @@ namespace nd4j {
|
||||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||||
|
|
||||||
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
if (numOfSubArrs > 0)
|
||||||
|
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||||
|
|
||||||
Nd4jPointer soPtr;
|
Nd4jPointer soPtr;
|
||||||
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));
|
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));
|
||||||
|
|
|
@ -54,11 +54,6 @@ namespace nd4j {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) {
|
Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) {
|
||||||
|
|
||||||
if (rank)
|
|
||||||
if(shapeOnly[0] == 0) // scalar case
|
|
||||||
rank = 0;
|
|
||||||
|
|
||||||
Nd4jLong* shapeInfo = nullptr;
|
Nd4jLong* shapeInfo = nullptr;
|
||||||
|
|
||||||
if(rank == 0) { // scalar case
|
if(rank == 0) { // scalar case
|
||||||
|
@ -67,10 +62,23 @@ namespace nd4j {
|
||||||
else {
|
else {
|
||||||
ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
shapeInfo[0] = rank;
|
shapeInfo[0] = rank;
|
||||||
for(int i = 0; i < rank; ++i)
|
bool isEmpty = false;
|
||||||
|
for(int i = 0; i < rank; ++i) {
|
||||||
shapeInfo[i + 1] = shapeOnly[i];
|
shapeInfo[i + 1] = shapeOnly[i];
|
||||||
|
|
||||||
shape::updateStrides(shapeInfo, order);
|
if (shapeOnly[i] == 0)
|
||||||
|
isEmpty = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isEmpty) {
|
||||||
|
shape::updateStrides(shapeInfo, order);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
shapeInfo[shape::shapeInfoLength(rank) - 1] = order;
|
||||||
|
memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong));
|
||||||
|
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||||
|
}
|
||||||
|
|
||||||
nd4j::ArrayOptions::setDataType(shapeInfo, dataType);
|
nd4j::ArrayOptions::setDataType(shapeInfo, dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,9 +86,16 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) {
|
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) {
|
||||||
auto shape = createScalarShapeInfo(dataType, workspace);
|
auto shapeInfo = createScalarShapeInfo(dataType, workspace);
|
||||||
ArrayOptions::setPropertyBit(shape, ARRAY_EMPTY);
|
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||||
return shape;
|
return shapeInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace) {
|
||||||
|
auto shapeInfo = createShapeInfo(dataType, order, shape, workspace);
|
||||||
|
memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong));
|
||||||
|
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||||
|
return shapeInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -108,27 +108,81 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
||||||
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
|
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
|
||||||
return evalReduceShapeInfo(order, dimensions, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// evaluate output shape for reduce operation when input shape is empty
|
||||||
|
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
if (dimsToExclude.size() == 0) { // return copy of input shape
|
||||||
|
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
|
||||||
|
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||||
|
RELEASE(outShapeInfo, workspace);
|
||||||
|
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int rank = shape::rank(shapeInfo);
|
||||||
|
Nd4jLong* outShapeInfo = nullptr;
|
||||||
|
|
||||||
|
if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities
|
||||||
|
|
||||||
|
if(!keepDims)
|
||||||
|
outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
|
||||||
|
else
|
||||||
|
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
shape::checkDimensions(rank, dimsToExclude);
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> outShape;
|
||||||
|
|
||||||
|
if(keepDims) {
|
||||||
|
outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
|
||||||
|
for(const auto& dim : dimsToExclude)
|
||||||
|
outShape[dim] = 1;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (uint i = 0, j = 0; i < rank; ++i) {
|
||||||
|
if(j < dimsToExclude.size() && i == dimsToExclude[j])
|
||||||
|
++j;
|
||||||
|
else
|
||||||
|
outShape.emplace_back(shapeInfo[i + 1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||||
|
RELEASE(outShapeInfo, workspace);
|
||||||
|
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||||
return evalReduceShapeInfo(order, dimensions, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
|
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||||
|
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||||
return evalReduceShapeInfo(order, dimensions, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
|
return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shape resulting from reduce operation
|
// evaluate shape resulting from reduce operation
|
||||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
|
||||||
|
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
|
||||||
|
|
||||||
Nd4jLong* newShapeInfo = nullptr;
|
Nd4jLong* newShapeInfo = nullptr;
|
||||||
|
|
||||||
int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));
|
int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));
|
||||||
|
|
||||||
if (dimensions.size() == 0) { // return scalar or array with len=1 in this case
|
if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case
|
||||||
|
|
||||||
if(keepDims && rank > 1) {
|
if(keepDims && rank > 1) {
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
|
@ -157,16 +211,16 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
shape::checkDimensions(rank, dimensions);
|
shape::checkDimensions(rank, dimsToExclude);
|
||||||
|
|
||||||
int dimSize = dimensions.size();
|
int dimSize = dimsToExclude.size();
|
||||||
|
|
||||||
if(keepDims) {
|
if(keepDims) {
|
||||||
|
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
newShapeInfo[0] = rank;
|
newShapeInfo[0] = rank;
|
||||||
for(int i = 0; i < rank; ++i)
|
for(int i = 0; i < rank; ++i)
|
||||||
if (std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
|
if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
|
||||||
newShapeInfo[i+1] = 1;
|
newShapeInfo[i+1] = 1;
|
||||||
else
|
else
|
||||||
newShapeInfo[i+1] = shapeInfo[i+1];
|
newShapeInfo[i+1] = shapeInfo[i+1];
|
||||||
|
@ -178,7 +232,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
||||||
}
|
}
|
||||||
|
|
||||||
int newRank = rank - dimSize;
|
int newRank = rank - dimSize;
|
||||||
if (newRank==0 || (dimSize==1 && dimensions[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
|
if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
|
||||||
|
|
||||||
if(supportOldShapes) {
|
if(supportOldShapes) {
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
|
||||||
|
@ -199,7 +253,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
||||||
newShapeInfo[0] = newRank; // set rank
|
newShapeInfo[0] = newRank; // set rank
|
||||||
int j=1;
|
int j=1;
|
||||||
for(int i = 0; i < rank; ++i)
|
for(int i = 0; i < rank; ++i)
|
||||||
if (!std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
|
if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
|
||||||
newShapeInfo[j++] = shapeInfo[i+1];
|
newShapeInfo[j++] = shapeInfo[i+1];
|
||||||
|
|
||||||
//ensure whether vector has proper shape for old shape type
|
//ensure whether vector has proper shape for old shape type
|
||||||
|
@ -208,7 +262,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
||||||
RELEASE(newShapeInfo, workspace);
|
RELEASE(newShapeInfo, workspace);
|
||||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2
|
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2
|
||||||
newShapeInfo[0] = 2;
|
newShapeInfo[0] = 2;
|
||||||
if (dimensions[0] == 0) {
|
if (dimsToExclude[0] == 0) {
|
||||||
newShapeInfo[1] = 1;
|
newShapeInfo[1] = 1;
|
||||||
newShapeInfo[2] = oldValue;
|
newShapeInfo[2] = oldValue;
|
||||||
}
|
}
|
||||||
|
@ -422,8 +476,23 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
|
||||||
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
|
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
|
||||||
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
|
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
|
||||||
|
|
||||||
|
// nullify zero axis
|
||||||
|
for (int e = 0; e < maxRank; e++)
|
||||||
|
if (maxShapeInfo[e+1] == 0)
|
||||||
|
tmpShapeInfo[e+1] = 0;
|
||||||
|
|
||||||
|
int delta = maxRank - minRank;
|
||||||
|
for (int e = minRank - 1; e >= 0; e--)
|
||||||
|
if (minShapeInfo[e + 1] == 0)
|
||||||
|
tmpShapeInfo[e + 1 + delta] = 0;
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
|
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
|
||||||
|
|
||||||
|
if (shape::isEmpty(max) || shape::isEmpty(min)) {
|
||||||
|
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
|
||||||
|
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
|
||||||
|
}
|
||||||
|
|
||||||
ShapeDescriptor descriptor(tmpShapeInfo);
|
ShapeDescriptor descriptor(tmpShapeInfo);
|
||||||
RELEASE(tmpShapeInfo, workspace);
|
RELEASE(tmpShapeInfo, workspace);
|
||||||
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
@ -805,7 +874,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo,
|
||||||
nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
|
nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
|
||||||
throw std::invalid_argument("");
|
throw std::invalid_argument("");
|
||||||
}
|
}
|
||||||
return std::vector<Nd4jLong>({0});
|
return std::vector<Nd4jLong>({});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1992,7 +1992,7 @@ template <typename T>
|
||||||
len = shape::length(shapeInfo);
|
len = shape::length(shapeInfo);
|
||||||
|
|
||||||
//check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute
|
//check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute
|
||||||
if(len < 2)
|
if(len == 1)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
const int rank = shape::rank(shapeInfo);
|
const int rank = shape::rank(shapeInfo);
|
||||||
|
@ -3961,7 +3961,7 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
|
||||||
newDim = newShape[newStart];
|
newDim = newShape[newStart];
|
||||||
oldDim = oldShape[oldStart];
|
oldDim = oldShape[oldStart];
|
||||||
|
|
||||||
while (newDim != oldDim)
|
while (newDim != oldDim && newDim > 0 && oldDim > 0)
|
||||||
if (newDim < oldDim) newDim *= newShape[newStop++];
|
if (newDim < oldDim) newDim *= newShape[newStop++];
|
||||||
else oldDim *= oldShape[oldStop++];
|
else oldDim *= oldShape[oldStop++];
|
||||||
|
|
||||||
|
|
|
@ -116,13 +116,23 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||||
|
|
||||||
|
const Nd4jLong zLen = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto indexValue = OpType::startingIndexValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < zLen; i++)
|
||||||
|
z[i] = indexValue.index;;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if(shape::isScalar(zShapeInfo)) {
|
if(shape::isScalar(zShapeInfo)) {
|
||||||
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
|
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const Nd4jLong zLen = shape::length(zShapeInfo);
|
|
||||||
|
|
||||||
auto tadOnlyShapeInfo = tadShapeInfo;
|
auto tadOnlyShapeInfo = tadShapeInfo;
|
||||||
Nd4jLong *tadOffsets = tadOffset;
|
Nd4jLong *tadOffsets = tadOffset;
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,21 @@ namespace functions {
|
||||||
const Nd4jLong length = shape::length(xShapeInfo);
|
const Nd4jLong length = shape::length(xShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(xShapeInfo)) {
|
||||||
|
z[0] = OpType::startingValue(x);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < length; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (xEws >= 1) {
|
if (xEws >= 1) {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
|
@ -157,6 +172,16 @@ namespace functions {
|
||||||
|
|
||||||
auto resultLength = shape::length(zShapeInfo);
|
auto resultLength = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < resultLength; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
//pre squeezed: this is for keeping the pointer to the original
|
||||||
//shape information for tad offset
|
//shape information for tad offset
|
||||||
//the squeezed information doesn't render the right strides for
|
//the squeezed information doesn't render the right strides for
|
||||||
|
|
|
@ -46,6 +46,25 @@ namespace functions {
|
||||||
const Nd4jLong length = shape::length(xShapeInfo);
|
const Nd4jLong length = shape::length(xShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(xShapeInfo)) {
|
||||||
|
if (std::is_same<OpType, simdOps::Mean<X,Z>>::value) {
|
||||||
|
z[0] = nd4j::DataTypeUtils::nanOrZero<Z>();
|
||||||
|
} else {
|
||||||
|
z[0] = OpType::startingValue(x);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < length; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (xEws > 0) {
|
if (xEws > 0) {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
|
@ -165,6 +184,16 @@ namespace functions {
|
||||||
|
|
||||||
auto resultLength = shape::length(zShapeInfo);
|
auto resultLength = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < resultLength; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
//pre squeezed: this is for keeping the pointer to the original
|
||||||
//shape information for tad offset
|
//shape information for tad offset
|
||||||
//the squeezed information doesn't render the right strides for
|
//the squeezed information doesn't render the right strides for
|
||||||
|
|
|
@ -46,6 +46,21 @@ namespace functions {
|
||||||
const Nd4jLong length = shape::length(xShapeInfo);
|
const Nd4jLong length = shape::length(xShapeInfo);
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(xShapeInfo)) {
|
||||||
|
z[0] = OpType::startingValue(x);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < length; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (xEws >= 1) {
|
if (xEws >= 1) {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
}
|
}
|
||||||
|
@ -159,6 +174,16 @@ namespace functions {
|
||||||
|
|
||||||
auto resultLength = shape::length(zShapeInfo);
|
auto resultLength = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < resultLength; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
//pre squeezed: this is for keeping the pointer to the original
|
||||||
//shape information for tad offset
|
//shape information for tad offset
|
||||||
//the squeezed information doesn't render the right strides for
|
//the squeezed information doesn't render the right strides for
|
||||||
|
|
|
@ -48,6 +48,20 @@ namespace functions {
|
||||||
const auto xEws = shape::elementWiseStride(xShapeInfo);
|
const auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
const int rank = shape::rank(xShapeInfo);
|
const int rank = shape::rank(xShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(xShapeInfo)) {
|
||||||
|
z[0] = OpType::startingValue(x);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < length; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (xEws >= 1) {
|
if (xEws >= 1) {
|
||||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||||
|
@ -71,7 +85,7 @@ namespace functions {
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
start = OpType::update(start, intermediate[e], extraParams);
|
start = OpType::update(start, intermediate[e], extraParams);
|
||||||
|
|
||||||
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
z[0] = OpType::postProcess(start, length, extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -171,6 +185,16 @@ namespace functions {
|
||||||
|
|
||||||
auto zLength = shape::length(zShapeInfo);
|
auto zLength = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < zLength; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
//pre squeezed: this is for keeping the pointer to the original
|
||||||
//shape information for tad offset
|
//shape information for tad offset
|
||||||
//the squeezed information doesn't render the right strides for
|
//the squeezed information doesn't render the right strides for
|
||||||
|
|
|
@ -47,6 +47,16 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
||||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY || nd4j::ArrayOptions::arrayType(yShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
const auto startingVal = OpType::startingValue(x);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < length; i++)
|
||||||
|
z[i] = startingVal;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
|
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
|
||||||
// it's possible case for EqualsWithEps op
|
// it's possible case for EqualsWithEps op
|
||||||
if (extraParams != nullptr)
|
if (extraParams != nullptr)
|
||||||
|
|
|
@ -47,8 +47,8 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *resultShapeInfoBuffer) {
|
Nd4jLong *zShapeInfo) {
|
||||||
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer), SUMMARY_STATS_OPS);
|
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
|
@ -58,10 +58,10 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *extraParams,
|
void *extraParams,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *resultShapeInfoBuffer,
|
Nd4jLong *zShapeInfo,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength) {
|
int dimensionLength) {
|
||||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer, dimension, dimensionLength), SUMMARY_STATS_OPS);
|
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Z>
|
template <typename X, typename Z>
|
||||||
|
@ -71,7 +71,7 @@ namespace functions {
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *resultShapeInfoBuffer) {
|
Nd4jLong *zShapeInfo) {
|
||||||
auto z = reinterpret_cast<Z*>(vz);
|
auto z = reinterpret_cast<Z*>(vz);
|
||||||
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
|
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
|
||||||
}
|
}
|
||||||
|
@ -108,20 +108,31 @@ namespace functions {
|
||||||
void *vx,
|
void *vx,
|
||||||
Nd4jLong *xShapeInfo,
|
Nd4jLong *xShapeInfo,
|
||||||
void *vextraParams,
|
void *vextraParams,
|
||||||
void *vresult,
|
void *vz,
|
||||||
Nd4jLong *resultShapeInfoBuffer,
|
Nd4jLong *zShapeInfo,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength) {
|
int dimensionLength) {
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
|
||||||
auto z = reinterpret_cast<Z *>(vresult);
|
|
||||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
|
||||||
|
|
||||||
if (shape::isScalar(resultShapeInfoBuffer)) {
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||||
|
int resultLength = shape::length(zShapeInfo);
|
||||||
|
|
||||||
|
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||||
|
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||||
|
return;
|
||||||
|
SummaryStatsData<X> comp;
|
||||||
|
comp.initWithValue(x[0]);
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||||
|
for (uint i = 0; i < resultLength; i++)
|
||||||
|
z[i] = OpType::getValue(biasCorrected, comp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (shape::isScalar(zShapeInfo)) {
|
||||||
|
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//no-op
|
//no-op
|
||||||
if (dimensionLength < 1)
|
if (dimensionLength < 1)
|
||||||
|
@ -129,7 +140,6 @@ namespace functions {
|
||||||
|
|
||||||
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||||
|
|
||||||
int resultLength = shape::length(resultShapeInfoBuffer);
|
|
||||||
//pre squeezed: this is for keeping the pointer to the original
|
//pre squeezed: this is for keeping the pointer to the original
|
||||||
//shape information for tad offset
|
//shape information for tad offset
|
||||||
//the squeezed information doesn't render the right strides for
|
//the squeezed information doesn't render the right strides for
|
||||||
|
|
|
@ -131,7 +131,7 @@ namespace nd4j {
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
||||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
||||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2d(*block.launchContext(), inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||||
|
|
||||||
delete inputReshaped;
|
delete inputReshaped;
|
||||||
delete outputReshaped;
|
delete outputReshaped;
|
||||||
|
@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
||||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(*block.launchContext(), inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||||
|
|
||||||
delete inputReshaped;
|
delete inputReshaped;
|
||||||
delete gradIReshaped;
|
delete gradIReshaped;
|
||||||
|
|
|
@ -63,7 +63,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -305,7 +305,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
||||||
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
|
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
|
||||||
|
|
||||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||||
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
|
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
|
||||||
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
|
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
|
||||||
|
|
||||||
|
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
|
|
||||||
// ----- calculation of gradW and gradB ----- //
|
// ----- calculation of gradW and gradB ----- //
|
||||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||||
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
|
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
|
||||||
|
|
||||||
if(gradB) {
|
if(gradB) {
|
||||||
|
@ -469,7 +469,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
||||||
|
|
||||||
//----- calculation of gradI -----//
|
//----- calculation of gradI -----//
|
||||||
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
|
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
|
||||||
ConvolutionUtils::col2vol(*block.launchContext(), columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
||||||
|
|
||||||
if(!isNDHWC) {
|
if(!isNDHWC) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
|
||||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
||||||
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
|
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
|
||||||
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
|
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||||
ConvolutionUtils::col2vol(*block.launchContext(), columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
||||||
|
|
||||||
//----- add biases if required -----//
|
//----- add biases if required -----//
|
||||||
if(bias)
|
if(bias)
|
||||||
|
@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
||||||
|
|
||||||
// ----- calculation of gradW ----- //
|
// ----- calculation of gradW ----- //
|
||||||
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::vol2col(*block.launchContext(), *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
|
ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
|
||||||
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
|
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
|
||||||
|
|
||||||
// ----- calculation of gradB ----- //
|
// ----- calculation of gradB ----- //
|
||||||
|
|
|
@ -62,7 +62,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
ConvolutionUtils::depthwiseConv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -185,7 +185,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
|
||||||
if(bias)
|
if(bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
|
||||||
if (bias)
|
if (bias)
|
||||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||||
|
|
||||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
|
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
||||||
//output->printBuffer("output op");
|
//output->printBuffer("output op");
|
||||||
|
|
||||||
if (!isNCHW) {
|
if (!isNCHW) {
|
||||||
|
@ -198,7 +198,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
// *gradI /= kH*kW;
|
// *gradI /= kH*kW;
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
|
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -69,7 +69,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
//T extraParams[] = {};
|
//T extraParams[] = {};
|
||||||
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
@ -189,7 +189,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
||||||
|
|
||||||
if (!isNCHW) {
|
if (!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
||||||
|
|
||||||
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
||||||
|
|
||||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
|
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
||||||
if(isSameMode) // SAME
|
if(isSameMode) // SAME
|
||||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||||
|
|
||||||
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
@ -204,7 +204,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
||||||
// ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
// ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary;
|
||||||
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||||
|
|
||||||
if(!isNCDHW) {
|
if(!isNCDHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
||||||
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);
|
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
||||||
|
|
||||||
if (!isNCHW) {
|
if (!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
@ -209,7 +209,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
||||||
|
|
||||||
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
||||||
|
|
||||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
|
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
|
||||||
|
|
||||||
if(!isNCHW) {
|
if(!isNCHW) {
|
||||||
delete input;
|
delete input;
|
||||||
|
|
|
@ -84,11 +84,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
|
||||||
|
|
||||||
if (iC == 1) {
|
if (iC == 1) {
|
||||||
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
|
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
|
||||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -274,12 +274,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
|
|
||||||
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
|
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
|
||||||
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
|
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
|
||||||
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||||
|
|
||||||
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||||
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
||||||
|
|
||||||
ConvolutionUtils::conv2dBP(*block.launchContext(), resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
|
ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
|
||||||
|
|
||||||
gradO = gradIDepth;
|
gradO = gradIDepth;
|
||||||
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
|
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
|
||||||
|
@ -288,7 +288,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----- apply depthwise_conv2d_bp ----- //
|
// ----- apply depthwise_conv2d_bp ----- //
|
||||||
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||||
|
|
||||||
if(weightsPoint)
|
if(weightsPoint)
|
||||||
delete gradO;
|
delete gradO;
|
||||||
|
|
|
@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf());
|
||||||
REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf());
|
REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf());
|
||||||
|
|
||||||
ConvolutionUtils::upsampling2d(*block.launchContext(), *input, *output, factorH, factorW, (bool)isNCHW);
|
ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
|
||||||
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
||||||
REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
||||||
|
|
||||||
ConvolutionUtils::upsampling2dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCHW);
|
ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
|
||||||
REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf());
|
REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf());
|
||||||
REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf());
|
REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf());
|
||||||
|
|
||||||
ConvolutionUtils::upsampling3d(*block.launchContext(), *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
|
ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
|
||||||
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
||||||
REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
||||||
|
|
||||||
ConvolutionUtils::upsampling3dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCDHW);
|
ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
||||||
NDArray weights = NDArrayFactory::create<float>('c', shape, block.getWorkspace());
|
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
||||||
weights({0, 1, 0, 0}).assign(1.0f);
|
weights({0, 1, 0, 0}).assign(1.0f);
|
||||||
weights({1, 2, 0, 0}).assign(0.0f);
|
weights({1, 2, 0, 0}).assign(0.0f);
|
||||||
|
|
||||||
|
|
|
@ -72,6 +72,11 @@ namespace nd4j {
|
||||||
if (dims.size() > 1)
|
if (dims.size() > 1)
|
||||||
std::sort(dims.begin(), dims.end());
|
std::sort(dims.begin(), dims.end());
|
||||||
|
|
||||||
|
|
||||||
|
for (auto d:dims) {
|
||||||
|
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
|
||||||
|
}
|
||||||
|
|
||||||
// special case - output is scalar
|
// special case - output is scalar
|
||||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));
|
||||||
|
|
|
@ -72,6 +72,10 @@ namespace nd4j {
|
||||||
if (dims.size() > 1)
|
if (dims.size() > 1)
|
||||||
std::sort(dims.begin(), dims.end());
|
std::sort(dims.begin(), dims.end());
|
||||||
|
|
||||||
|
for (auto d:dims) {
|
||||||
|
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
|
||||||
|
}
|
||||||
|
|
||||||
// special case - output is scalar
|
// special case - output is scalar
|
||||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));
|
||||||
|
|
|
@ -71,10 +71,8 @@ namespace nd4j {
|
||||||
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong);
|
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong);
|
||||||
|
|
||||||
newShape[0] = len;
|
newShape[0] = len;
|
||||||
auto empty = false;
|
|
||||||
for (int e = 0; e < shapeArray->lengthOf(); e++){
|
for (int e = 0; e < shapeArray->lengthOf(); e++){
|
||||||
newShape[e+1] = shapeArray->e<Nd4jLong>(e);
|
newShape[e+1] = shapeArray->e<Nd4jLong>(e);
|
||||||
empty |= (newShape[e+1] == 0); //Support "zeros in shape as empty" for TF import
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::DataType dataType;
|
nd4j::DataType dataType;
|
||||||
|
@ -90,10 +88,6 @@ namespace nd4j {
|
||||||
} else
|
} else
|
||||||
throw std::runtime_error("Fill: missing value to fill output array with");
|
throw std::runtime_error("Fill: missing value to fill output array with");
|
||||||
|
|
||||||
if(empty){
|
|
||||||
return SHAPELIST(ShapeBuilders::emptyShapeInfo(dataType, block.getWorkspace()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(newShape, dataType, 'c');
|
ShapeUtils::updateStridesAndType(newShape, dataType, 'c');
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(newShape));
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
|
|
|
@ -151,8 +151,10 @@ DECLARE_SHAPE_FN(range) {
|
||||||
delta = INPUT_VARIABLE(2)->e<double>(0);
|
delta = INPUT_VARIABLE(2)->e<double>(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (limit == start)
|
if (limit == start){
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
|
//Return [0] to match TF
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
|
@ -177,8 +179,10 @@ DECLARE_SHAPE_FN(range) {
|
||||||
|
|
||||||
//nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta)
|
//nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta)
|
||||||
|
|
||||||
if (limit == start)
|
if (limit == start){
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
|
//Return [0] to match TF
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
|
@ -203,8 +207,10 @@ DECLARE_SHAPE_FN(range) {
|
||||||
delta = INT_ARG(2);
|
delta = INT_ARG(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (limit == start)
|
if (limit == start){
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(nd4j::DataType::INT32));
|
//Return [0] to match TF
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::INT32));
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
||||||
|
@ -233,9 +239,10 @@ DECLARE_SHAPE_FN(range) {
|
||||||
delta = T_ARG(2);
|
delta = T_ARG(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
//REQUIRE_TRUE(limit != start, 0, "CUSTOM RANGE OP: limit and start values should be different, but got both equal to %f !", limit);
|
if (limit == start){
|
||||||
if (limit == start)
|
//Return [0] to match TF
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(Environment::getInstance()->defaultFloatDataType()));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType()));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||||
|
|
||||||
output->assign(static_cast<Nd4jLong>(input->rankOf()));
|
// output->assign(static_cast<Nd4jLong>(input->rankOf()));
|
||||||
|
output->assign(input->rankOf());
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,17 +132,13 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]);
|
REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]);
|
||||||
REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]);
|
REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]);
|
||||||
REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]);
|
REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]);
|
||||||
if(start == inShape[e+1] || size == 0 ){
|
if(start == inShape[e+1] ){
|
||||||
empty = true;
|
size = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
shape.emplace_back(size);
|
shape.emplace_back(size);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(empty){
|
|
||||||
return SHAPELIST(ShapeBuilders::emptyShapeInfo(nd4j::DataType::INT32, block.getWorkspace()));
|
|
||||||
}
|
|
||||||
|
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(newShape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,6 +34,10 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
|
||||||
if(dim < 0)
|
if(dim < 0)
|
||||||
dim += input->rankOf() + 1;
|
dim += input->rankOf() + 1;
|
||||||
|
|
||||||
|
// no-op in case of empty output array
|
||||||
|
if (output->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
// input validation
|
// input validation
|
||||||
// check whether shapes of all input array are the same
|
// check whether shapes of all input array are the same
|
||||||
for (int i = 0; i < (int) block.width() - 1; ++i)
|
for (int i = 0; i < (int) block.width() - 1; ++i)
|
||||||
|
@ -48,16 +52,6 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
|
||||||
|
|
||||||
helpers::stack(block.launchContext(), inArrs, output, dim);
|
helpers::stack(block.launchContext(), inArrs, output, dim);
|
||||||
|
|
||||||
// remove unity from output shape if input arrays are vectors
|
|
||||||
// if(input->isVector()) {
|
|
||||||
// std::vector<int> outShape(output->shapeOf(), output->shapeOf() + output->rankOf());
|
|
||||||
// outShape.erase(find(outShape.begin(), outShape.end(), 1));
|
|
||||||
// output->reshapei(output->ordering(), outShape);
|
|
||||||
// if(dim != 0 && (int)block.width() == 1) // such is implemented by tensorFlow
|
|
||||||
// output->permutei({1, 0});
|
|
||||||
// output->getShapeInfo()[output->rankOf()*2 + 2] = 1;
|
|
||||||
// }
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
DECLARE_SYN(pack, stack);
|
DECLARE_SYN(pack, stack);
|
||||||
|
@ -82,8 +76,22 @@ DECLARE_SHAPE_FN(stack) {
|
||||||
|
|
||||||
REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim);
|
REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim);
|
||||||
|
|
||||||
|
// empty input arrays require some special handling
|
||||||
|
if (shape::isEmpty(inShapeInfo)) {
|
||||||
|
switch (rank) {
|
||||||
|
case 0: {
|
||||||
|
// we're going to return rank 1 here
|
||||||
|
if (block.width() == 1) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo)));
|
||||||
|
} else {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if(rank == 0) {
|
if(rank == 0) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
|
||||||
}
|
}
|
||||||
|
|
||||||
//the rank of output ShapeInfo is larger by one compared to input ShapeInfo
|
//the rank of output ShapeInfo is larger by one compared to input ShapeInfo
|
||||||
|
@ -91,13 +99,9 @@ DECLARE_SHAPE_FN(stack) {
|
||||||
|
|
||||||
// insert (int) block.width() at dim position of input shape to get output shape
|
// insert (int) block.width() at dim position of input shape to get output shape
|
||||||
outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width());
|
outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width());
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1) 1х4 + 1х4 = 2х1х4 (along dim=0) = 2x4
|
|
||||||
// 2) 1х4 + 1х4 = 1х2х4 (along dim=1) = 2x4
|
|
||||||
// 3) 4х1 + 4х1 = 2х4x1 (along dim=0) = 2x4
|
|
||||||
// 4) 4х1 + 4х1 = 4х2x1 (along dim=1) = 4x2
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -410,10 +410,13 @@ namespace nd4j {
|
||||||
// z->assign(x->e<float>(indices[0]));
|
// z->assign(x->e<float>(indices[0]));
|
||||||
// }
|
// }
|
||||||
// else {
|
// else {
|
||||||
auto sub = (*x)(indices, true, true);
|
if (indices.size()) {
|
||||||
z->assign(sub);
|
auto sub = (*x)(indices, true, true);
|
||||||
// }
|
z->assign(sub);
|
||||||
|
}
|
||||||
|
else if (!z->isEmpty()){
|
||||||
|
z->assign(x->e(0));
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
DECLARE_SYN(stridedslice, strided_slice);
|
DECLARE_SYN(stridedslice, strided_slice);
|
||||||
|
@ -496,28 +499,19 @@ namespace nd4j {
|
||||||
bool is_simple_slice;
|
bool is_simple_slice;
|
||||||
bool is_dim0;
|
bool is_dim0;
|
||||||
|
|
||||||
// FIXME: remove this, once we bring in 1D NDArrays
|
std::vector<Nd4jLong> indices;
|
||||||
//vectorize(input_shape);
|
bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
|
||||||
bool result = _preprocess_strided_slice(nullptr, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
|
if (indices.size()) {
|
||||||
bool nonEmpty = shape.size() > 0;
|
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
||||||
if (nonEmpty)
|
shape);
|
||||||
for (auto x: shape) {
|
if (inputLen > 1) {
|
||||||
if (x == 0) {
|
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
||||||
nonEmpty = false;
|
shape);
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (nonEmpty && inputLen > 1) {
|
|
||||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
if (shape::rank(inShape) == 0 || begin >= end) {
|
|
||||||
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
|
|
||||||
} else {
|
} else {
|
||||||
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
} else
|
||||||
|
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
|
||||||
|
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(newShape);
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,9 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
|
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
|
||||||
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
|
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
|
||||||
|
|
||||||
|
if(input->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
for (int e = 0; e < input->rankOf(); e++)
|
for (int e = 0; e < input->rankOf(); e++)
|
||||||
if (e != dim)
|
if (e != dim)
|
||||||
|
@ -76,6 +79,21 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
|
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
|
||||||
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
|
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
|
||||||
|
|
||||||
|
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
|
||||||
|
if(shape::shapeOf(inShape)[dim] == 0)
|
||||||
|
return SHAPELIST();
|
||||||
|
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
|
||||||
|
std::vector<Nd4jLong> outShape;
|
||||||
|
for(uint i = 0; i < shape::rank(inShape); ++i)
|
||||||
|
if(i != dim)
|
||||||
|
outShape.push_back(shape::shapeOf(inShape)[i]);
|
||||||
|
|
||||||
|
auto result = SHAPELIST();
|
||||||
|
for(uint i = 0; i < numTads; ++i)
|
||||||
|
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
for (int e = 0; e < shape::rank(inShape); e++)
|
for (int e = 0; e < shape::rank(inShape); e++)
|
||||||
if (e != dim)
|
if (e != dim)
|
||||||
|
|
|
@ -30,6 +30,12 @@ namespace nd4j {
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||||
|
|
||||||
|
if(input->isEmpty()){
|
||||||
|
output->p<double>(0, std::numeric_limits<double>::quiet_NaN());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
int numZeros = 0;
|
int numZeros = 0;
|
||||||
// for (int e = 0; e < input->lengthOf(); e++)
|
// for (int e = 0; e < input->lengthOf(); e++)
|
||||||
// if ((*input)(e) == T(0))
|
// if ((*input)(e) == T(0))
|
||||||
|
|
|
@ -113,16 +113,10 @@ DECLARE_SHAPE_FN(lstmBlock) {
|
||||||
}
|
}
|
||||||
ShapeUtils::updateStridesAndType(s, x, 'c');
|
ShapeUtils::updateStridesAndType(s, x, 'c');
|
||||||
|
|
||||||
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
|
Nd4jLong *s1 = CONSTANT(s);
|
||||||
COPY_SHAPE(s, s1);
|
|
||||||
COPY_SHAPE(s, s2);
|
|
||||||
COPY_SHAPE(s, s3);
|
|
||||||
COPY_SHAPE(s, s4);
|
|
||||||
COPY_SHAPE(s, s5);
|
|
||||||
COPY_SHAPE(s, s6);
|
|
||||||
|
|
||||||
//7 outputs, all same shape/type
|
//7 outputs, all same shape/type
|
||||||
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
|
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -115,16 +115,10 @@ DECLARE_SHAPE_FN(lstmBlockCell) {
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(s, xt, 'c');
|
ShapeUtils::updateStridesAndType(s, xt, 'c');
|
||||||
|
|
||||||
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
|
Nd4jLong *s1 = CONSTANT(s);
|
||||||
COPY_SHAPE(s, s1);
|
|
||||||
COPY_SHAPE(s, s2);
|
|
||||||
COPY_SHAPE(s, s3);
|
|
||||||
COPY_SHAPE(s, s4);
|
|
||||||
COPY_SHAPE(s, s5);
|
|
||||||
COPY_SHAPE(s, s6);
|
|
||||||
|
|
||||||
//7 outputs, all same shape: z, i, f, o, h, c, y
|
//7 outputs, all same shape: z, i, f, o, h, c, y
|
||||||
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
|
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,6 @@ DECLARE_SHAPE_FN(broadcast_to) {
|
||||||
REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
|
REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
|
||||||
|
|
||||||
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape);
|
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape);
|
||||||
|
|
||||||
return SHAPELIST(outShapeInfo);
|
return SHAPELIST(outShapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,26 +34,26 @@ namespace nd4j {
|
||||||
|
|
||||||
bool replace = false;
|
bool replace = false;
|
||||||
|
|
||||||
auto arguments = block.getIArguments();
|
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
if (block.width() == 2 && arguments->size() == 0) {
|
std::vector<int> arguments({});
|
||||||
auto axis = INPUT_VARIABLE(1);
|
if(origArgs.size() > 0){
|
||||||
for (int e = 0; e < axis->lengthOf(); e++) {
|
for (int e = 0; e < origArgs.size(); e++) {
|
||||||
int ax = axis->e<int>(e);
|
int ax = origArgs[e];
|
||||||
if (ax < 0)
|
if (ax < 0)
|
||||||
ax += x->rankOf();
|
ax += x->rankOf();
|
||||||
|
|
||||||
arguments->emplace_back(ax);
|
arguments.emplace_back(ax);
|
||||||
}
|
}
|
||||||
|
|
||||||
replace = true;
|
replace = true;
|
||||||
} else if (arguments->size() == 0) {
|
} else {
|
||||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
for (int e = x->rankOf() - 1; e >= 0; e--)
|
||||||
arguments->emplace_back(e);
|
arguments.emplace_back(e);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 0D edge case
|
// 0D edge case
|
||||||
if (x->rankOf() == 0) {
|
if (x->rankOf() == 0) {
|
||||||
REQUIRE_TRUE(arguments->size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
if (!block.isInplace())
|
if (!block.isInplace())
|
||||||
output->assign(x);
|
output->assign(x);
|
||||||
|
@ -62,25 +62,17 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
if(block.isInplace()) { // in-place
|
if(block.isInplace()) { // in-place
|
||||||
x->permutei(*arguments);
|
x->permutei(arguments);
|
||||||
STORE_RESULT(x);
|
STORE_RESULT(x);
|
||||||
} else {
|
} else {
|
||||||
if (!replace) { // not-in-place
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto result = x->permute(arguments);
|
||||||
// nd4j_printv("permute shape", *arguments);
|
output->assign(result);
|
||||||
auto result = x->permute(*arguments);
|
STORE_RESULT(output);
|
||||||
output->assign(result);
|
delete result;
|
||||||
STORE_RESULT(output);
|
|
||||||
delete result;
|
|
||||||
} else {
|
|
||||||
auto output = OUTPUT_VARIABLE(0); //->dup();
|
|
||||||
output->assign(x);
|
|
||||||
output->permutei(*arguments);
|
|
||||||
|
|
||||||
//OVERWRITE_RESULT(output);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(permute) {
|
DECLARE_TYPES(permute) {
|
||||||
|
@ -92,20 +84,21 @@ namespace nd4j {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(permute) {
|
DECLARE_SHAPE_FN(permute) {
|
||||||
auto shapeList = SHAPELIST();
|
auto shapeList = SHAPELIST();
|
||||||
auto arguments = block.getIArguments();
|
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
|
|
||||||
if (shape::rank(inputShape->at(0)) == 0) {
|
if (shape::rank(inputShape->at(0)) == 0) {
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
||||||
} else if (inputShape->size() == 1 && !arguments->empty()) {
|
} else if (inputShape->size() == 1 && !arguments.empty()) {
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
|
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||||
} else if (inputShape->size() == 2) {
|
|
||||||
// dead end
|
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
|
|
||||||
} else {
|
} else {
|
||||||
int rank = shape::rank(inputShape->at(0));
|
if(arguments.size() == 0){
|
||||||
for (int e = rank - 1; e >= 0; e--)
|
//Reverse dimensions
|
||||||
arguments->emplace_back(e);
|
int rank = shape::rank(inputShape->at(0));
|
||||||
|
for (int e = rank - 1; e >= 0; e--)
|
||||||
|
arguments.emplace_back(e);
|
||||||
|
}
|
||||||
|
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
|
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
|
|
|
@ -35,9 +35,8 @@ namespace nd4j {
|
||||||
auto arguments = block.getIArguments();
|
auto arguments = block.getIArguments();
|
||||||
int argsSize = arguments->size();
|
int argsSize = arguments->size();
|
||||||
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||||
if (x->isEmpty()) {
|
if (x->isEmpty()) {
|
||||||
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
|
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||||
return ND4J_STATUS_OK; //No op
|
return ND4J_STATUS_OK; //No op
|
||||||
}
|
}
|
||||||
|
@ -96,9 +95,9 @@ namespace nd4j {
|
||||||
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
//Special case: empty.reshape(-1) -> return empty
|
||||||
if (x->isEmpty()) {
|
if (x->isEmpty()) {
|
||||||
REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||||
return ND4J_STATUS_OK; //No op
|
return Status::OK(); //No op
|
||||||
}
|
}
|
||||||
|
|
||||||
char order = 'c';
|
char order = 'c';
|
||||||
|
@ -116,7 +115,8 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
shapeLength *= s->e<Nd4jLong>(e2);
|
shapeLength *=
|
||||||
|
s->e<Nd4jLong>(e2);
|
||||||
}
|
}
|
||||||
long realShape = x->lengthOf() / shapeLength;
|
long realShape = x->lengthOf() / shapeLength;
|
||||||
shapeNew[e] = realShape;
|
shapeNew[e] = realShape;
|
||||||
|
@ -175,12 +175,12 @@ namespace nd4j {
|
||||||
e = 0;
|
e = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
// //Special case: empty.reshape(-1) -> return empty
|
||||||
if (INPUT_VARIABLE(0)->isEmpty()) {
|
// if (INPUT_VARIABLE(0)->isEmpty()) {
|
||||||
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
|
// //
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
||||||
return SHAPELIST(newShape);
|
// return SHAPELIST(newShape);
|
||||||
}
|
// }
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew;
|
std::vector<Nd4jLong> shapeNew;
|
||||||
|
|
||||||
|
@ -197,8 +197,14 @@ namespace nd4j {
|
||||||
shapeLength *= arguments->at(e2);
|
shapeLength *= arguments->at(e2);
|
||||||
}
|
}
|
||||||
|
|
||||||
long realShape = shape::length(inp) / shapeLength;
|
if(shapeLength == 0){
|
||||||
shapeNew.push_back(realShape);
|
//Edge case for empty:
|
||||||
|
shapeNew.push_back(0);
|
||||||
|
} else {
|
||||||
|
//Standard case
|
||||||
|
long realShape = shape::length(inp) / shapeLength;
|
||||||
|
shapeNew.push_back(realShape);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
shapeNew.push_back(arguments->at(e));
|
shapeNew.push_back(arguments->at(e));
|
||||||
|
@ -218,9 +224,16 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
//Special case: empty.reshape(-1) -> return empty
|
||||||
if (x->isEmpty()) {
|
if (x->isEmpty()) {
|
||||||
REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||||
return SHAPELIST(newShape);
|
Nd4jLong prod = 1;
|
||||||
|
for (auto v:shapeOf)
|
||||||
|
prod *= v;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||||
|
|
||||||
|
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||||
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
||||||
|
@ -236,8 +249,14 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
shapeLength *= y->e<Nd4jLong>(e2);
|
shapeLength *= y->e<Nd4jLong>(e2);
|
||||||
}
|
}
|
||||||
long realShape = shape::length(inp) / shapeLength;
|
|
||||||
shapeNew[e] = realShape;
|
if(shapeLength == 0){
|
||||||
|
//Edge case for empty:
|
||||||
|
shapeNew[e] = 0;
|
||||||
|
} else {
|
||||||
|
long realShape = shape::length(inp) / shapeLength;
|
||||||
|
shapeNew[e] = realShape;
|
||||||
|
}
|
||||||
}else {
|
}else {
|
||||||
shapeNew[e] = dim;
|
shapeNew[e] = dim;
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,21 +38,26 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
||||||
std::vector<int> arrsToDelete;
|
std::vector<int> arrsToDelete;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
bool allOfSameType = true;
|
bool allOfSameType = true;
|
||||||
|
auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0;
|
||||||
|
auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType();
|
||||||
for(int i = 0; i < block.width(); ++i) {
|
for(int i = 0; i < block.width(); ++i) {
|
||||||
|
auto input = INPUT_VARIABLE(i);
|
||||||
|
auto currentRank = input->rankOf();
|
||||||
|
|
||||||
if(!INPUT_VARIABLE(i)->isEmpty()) {
|
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
|
||||||
|
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
|
||||||
|
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
|
||||||
|
if(!input->isEmpty()) {
|
||||||
|
|
||||||
allOfSameType &= (INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType());
|
allOfSameType &= (theFirstDatatype == input->dataType());
|
||||||
if(INPUT_VARIABLE(i)->rankOf() == 0) {
|
if(input->rankOf() == 0) {
|
||||||
// FIXME, use this instead: block.dataType()
|
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
|
||||||
auto vec = new NDArray('c', {1}, INPUT_VARIABLE(0)->dataType(), block.launchContext());
|
vec->assign(input);
|
||||||
vec->assign(INPUT_VARIABLE(i));
|
|
||||||
nonEmptyArrs.push_back(vec);
|
nonEmptyArrs.push_back(vec);
|
||||||
arrsToDelete.push_back(index);
|
arrsToDelete.push_back(index);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
nonEmptyArrs.push_back(INPUT_VARIABLE(i));
|
nonEmptyArrs.push_back(input);
|
||||||
}
|
}
|
||||||
++index;
|
++index;
|
||||||
}
|
}
|
||||||
|
@ -113,33 +118,24 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
|
|
||||||
// first of all take into account possible presence of empty arrays
|
// first of all take into account possible presence of empty arrays
|
||||||
// also if scalar is present -> use the shape of vector with length=1 instead
|
// also if scalar is present -> use the shape of vector with length=1 instead
|
||||||
std::vector<Nd4jLong*> nonEmptyArrShapes;
|
std::vector<Nd4jLong*> arrShapes;
|
||||||
std::vector<int> shapesToDelete;
|
std::vector<int> shapesToDelete;
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for(int i = 0; i < block.width(); ++i) {
|
for(int i = 0; i < block.width(); ++i) {
|
||||||
|
|
||||||
if(!INPUT_VARIABLE(i)->isEmpty()) {
|
if(inputShape->at(i)[0] == 0) {
|
||||||
|
// FIXME, use this instead: block.dataType()
|
||||||
if(inputShape->at(i)[0] == 0) {
|
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||||
// FIXME, use this instead: block.dataType()
|
|
||||||
nonEmptyArrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
nonEmptyArrShapes.push_back(inputShape->at(i));
|
|
||||||
}
|
|
||||||
++index;
|
|
||||||
}
|
}
|
||||||
|
else{
|
||||||
|
arrShapes.push_back(inputShape->at(i));
|
||||||
|
}
|
||||||
|
++index;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int numOfArrs = nonEmptyArrShapes.size();
|
const int numOfArrs = arrShapes.size();
|
||||||
|
|
||||||
if(numOfArrs == 0){
|
const int rank = arrShapes[0][0];
|
||||||
//All inputs are empty arrays -> return empty, mainly for TF import compatibility
|
|
||||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
|
|
||||||
return SHAPELIST(empty);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int rank = nonEmptyArrShapes[0][0]; // look up to first non-empty array
|
|
||||||
|
|
||||||
int axis = INT_ARG(0);
|
int axis = INT_ARG(0);
|
||||||
if(axis < 0)
|
if(axis < 0)
|
||||||
|
@ -149,33 +145,33 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfArrs; ++i)
|
||||||
REQUIRE_TRUE(nonEmptyArrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i) {
|
for(int i = 1; i < numOfArrs; ++i) {
|
||||||
for(int dim = 0; dim < rank; ++dim)
|
for(int dim = 0; dim < rank; ++dim)
|
||||||
if(dim != axis)
|
if(dim != axis)
|
||||||
REQUIRE_TRUE(nonEmptyArrShapes[i][dim+1] == nonEmptyArrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
||||||
}
|
}
|
||||||
// ******** end of input validation ******** //
|
// ******** end of input validation ******** //
|
||||||
|
|
||||||
|
|
||||||
Nd4jLong* outShapeInfo(nullptr);
|
Nd4jLong* outShapeInfo(nullptr);
|
||||||
COPY_SHAPE(nonEmptyArrShapes[0], outShapeInfo);
|
COPY_SHAPE(arrShapes[0], outShapeInfo);
|
||||||
|
|
||||||
// case when we have only one input array
|
// case when we have only one input array
|
||||||
if(numOfArrs == 1) {
|
if(numOfArrs == 1) {
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfArrs; ++i)
|
||||||
outShapeInfo[axis + 1] += nonEmptyArrShapes[i][axis + 1];
|
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
|
|
||||||
// delete dynamically allocated vectors shapes with length=1
|
// delete dynamically allocated vectors shapes with length=1
|
||||||
for(int index : shapesToDelete)
|
for(int index : shapesToDelete)
|
||||||
RELEASE(nonEmptyArrShapes[index], block.getWorkspace());
|
RELEASE(arrShapes[index], block.getWorkspace());
|
||||||
|
|
||||||
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
|
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
|
||||||
RELEASE(outShapeInfo, block.getWorkspace());
|
RELEASE(outShapeInfo, block.getWorkspace());
|
||||||
|
|
|
@ -32,6 +32,11 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
||||||
|
|
||||||
|
if(input->isEmpty()){
|
||||||
|
//No-op
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
const bool exclusive = INT_ARG(0) == 1;
|
const bool exclusive = INT_ARG(0) == 1;
|
||||||
const bool reverse = INT_ARG(1) == 1;
|
const bool reverse = INT_ARG(1) == 1;
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,11 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) {
|
||||||
|
|
||||||
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
||||||
|
|
||||||
|
if(input->isEmpty()){
|
||||||
|
//No-op
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
if (block.getIArguments()->size() == 2 && block.width() == 1) {
|
if (block.getIArguments()->size() == 2 && block.width() == 1) {
|
||||||
// all at once case
|
// all at once case
|
||||||
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);
|
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);
|
||||||
|
|
|
@ -102,12 +102,6 @@ DECLARE_SHAPE_FN(gather) {
|
||||||
if(axis < 0)
|
if(axis < 0)
|
||||||
axis += inputRank;
|
axis += inputRank;
|
||||||
|
|
||||||
//Edge case: empty indices, empty input -> empty output
|
|
||||||
if(block.width() > 1 && INPUT_VARIABLE(0)->isEmpty() && INPUT_VARIABLE(1)->isEmpty()){
|
|
||||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
|
|
||||||
return SHAPELIST(empty);
|
|
||||||
}
|
|
||||||
|
|
||||||
REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
|
REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
|
||||||
|
|
||||||
bool isEmpty = false;
|
bool isEmpty = false;
|
||||||
|
@ -119,11 +113,6 @@ DECLARE_SHAPE_FN(gather) {
|
||||||
|
|
||||||
int outputRank = inputRank + indicesRank - 1;
|
int outputRank = inputRank + indicesRank - 1;
|
||||||
|
|
||||||
if(INPUT_VARIABLE(1)->isEmpty()) { //Empty indices -> empty output
|
|
||||||
outputRank = 0;
|
|
||||||
isEmpty = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);
|
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);
|
||||||
|
|
||||||
// fill output shapeInfo
|
// fill output shapeInfo
|
||||||
|
|
|
@ -33,6 +33,11 @@ namespace ops {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if(output->isEmpty()){
|
||||||
|
//No-op
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int> axis;
|
std::vector<int> axis;
|
||||||
|
|
||||||
if (block.width() > 1)
|
if (block.width() > 1)
|
||||||
|
|
|
@ -204,6 +204,8 @@ namespace nd4j {
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescConv3d(
|
static void getMKLDNNMemoryDescConv3d(
|
||||||
|
@ -212,56 +214,60 @@ namespace nd4j {
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||||
|
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescPool2d(
|
static void getMKLDNNMemoryDescPool2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||||
|
|
||||||
static void getMKLDNNMemoryDescPool3d(
|
static void getMKLDNNMemoryDescPool3d(
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
|
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||||
|
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static void conv2d(nd4j::LaunchContext &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
static void conv2d(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||||
|
|
||||||
static void conv2dBP(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
||||||
|
|
||||||
static void conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
static void depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
static void depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
static void sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||||
|
|
||||||
static void vol2col(nd4j::LaunchContext & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||||
|
|
||||||
static void col2vol(nd4j::LaunchContext & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
static void col2vol(nd4j::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||||
|
|
||||||
static void upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
|
static void upsampling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
|
||||||
|
|
||||||
static void upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
|
static void upsampling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
|
||||||
|
|
||||||
static void upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
|
static void upsampling2dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
|
||||||
|
|
||||||
static void upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
|
static void upsampling3dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
|
||||||
|
|
||||||
static void pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
|
static void pooling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
|
||||||
|
|
||||||
static void pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
static void pooling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||||
|
|
||||||
static void pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
static void pooling2dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||||
|
|
||||||
static void pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
static void pooling3dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||||
};
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -153,7 +153,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
|
||||||
if (inEWS == 1) {
|
if (inEWS == 1) {
|
||||||
PRAGMA_OMP_SIMD_MAX(max)
|
PRAGMA_OMP_SIMD_MAX(max)
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
max = nd4j::math::nd4j_max<T>(max, outBuff[i]);
|
max = nd4j::math::nd4j_max<T>(max, inBuff[i]);
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD_SUM(sum)
|
PRAGMA_OMP_SIMD_SUM(sum)
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
|
@ -171,7 +171,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD_MAX(max)
|
PRAGMA_OMP_SIMD_MAX(max)
|
||||||
for (int i = 0; i < length; i++)
|
for (int i = 0; i < length; i++)
|
||||||
max = nd4j::math::nd4j_max<T>(max, outBuff[i * inEWS]);
|
max = nd4j::math::nd4j_max<T>(max, inBuff[i * inEWS]);
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD_SUM(sum)
|
PRAGMA_OMP_SIMD_SUM(sum)
|
||||||
for (int i = 0; i < length; i++) {
|
for (int i = 0; i < length; i++) {
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -227,15 +227,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
|
||||||
|
|
||||||
//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
|
//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
|
||||||
std::vector<Nd4jLong> shape = {bS, 4*numUnits};
|
std::vector<Nd4jLong> shape = {bS, 4*numUnits};
|
||||||
auto m = NDArrayFactory::create_('c', shape, xt->dataType(), nullptr);
|
auto m = NDArrayFactory::create('c', shape, xt->dataType());
|
||||||
MmulHelper::mmul(&concatOut, W, m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
|
MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
|
||||||
*m += (*b); //addiRowVector
|
m += (*b); //addiRowVector
|
||||||
|
|
||||||
//Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
|
//Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
|
||||||
auto zi = (*m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
|
auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
|
||||||
auto zz = (*m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
|
auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
|
||||||
auto zf = (*m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
|
auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
|
||||||
auto zo = (*m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
|
auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
|
||||||
|
|
||||||
if(peephole) { // add peephole connections: z + ct_1*Wc
|
if(peephole) { // add peephole connections: z + ct_1*Wc
|
||||||
zi += (*cLast) * (*Wci); // add peephole connections to input gate
|
zi += (*cLast) * (*Wci); // add peephole connections to input gate
|
||||||
|
|
|
@ -57,7 +57,7 @@ namespace helpers {
|
||||||
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]);
|
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
|
ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
|
||||||
|
|
||||||
if (nullptr != indices) {
|
if (nullptr != indices) {
|
||||||
// for max_pool_with_argmax
|
// for max_pool_with_argmax
|
||||||
|
|
|
@ -53,9 +53,16 @@ namespace nd4j {
|
||||||
dtype = nd4j::DataType::BOOL;
|
dtype = nd4j::DataType::BOOL;
|
||||||
|
|
||||||
if(shape::isEmpty(x) || shape::isEmpty(y)) {
|
if(shape::isEmpty(x) || shape::isEmpty(y)) {
|
||||||
//Edge case: broadcasting with empty array gives empty array output (behaviour to match TF for import cases)
|
// this is edge case, [3, 4] + [] = []
|
||||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype);
|
if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
|
||||||
shapeList->push_back(empty);
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
|
||||||
|
return shapeList;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jLong *newshape = nullptr;
|
||||||
|
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
|
||||||
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
|
||||||
} else if (shape::isScalar(x) && shape::isScalar(y)) {
|
} else if (shape::isScalar(x) && shape::isScalar(y)) {
|
||||||
if (shape::rank(x) >= shape::rank(y)) {
|
if (shape::rank(x) >= shape::rank(y)) {
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||||
|
|
|
@ -2873,7 +2873,7 @@ namespace simdOps {
|
||||||
const static functions::ReduceType reduceType = functions::ReduceType::MAX;
|
const static functions::ReduceType reduceType = functions::ReduceType::MAX;
|
||||||
|
|
||||||
op_def static X startingValue(const X *input) {
|
op_def static X startingValue(const X *input) {
|
||||||
return -nd4j::DataTypeUtils::max<X>();
|
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
op_def static X merge(X old, X opOutput, X *extraParams) {
|
op_def static X merge(X old, X opOutput, X *extraParams) {
|
||||||
|
@ -3051,7 +3051,7 @@ namespace simdOps {
|
||||||
const static functions::ReduceType reduceType = functions::ReduceType::MIN;
|
const static functions::ReduceType reduceType = functions::ReduceType::MIN;
|
||||||
|
|
||||||
op_def static X startingValue(const X *input) {
|
op_def static X startingValue(const X *input) {
|
||||||
return nd4j::DataTypeUtils::max<X>();
|
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
op_def static X merge(X old, X opOutput, X *extraParams) {
|
op_def static X merge(X old, X opOutput, X *extraParams) {
|
||||||
|
@ -3831,7 +3831,7 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline X startingValue(const X *input) {
|
static _CUDA_HD inline X startingValue(const X *input) {
|
||||||
return -nd4j::DataTypeUtils::max<X>();
|
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||||
|
@ -3890,7 +3890,7 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline X startingValue(const X *input) {
|
static _CUDA_HD inline X startingValue(const X *input) {
|
||||||
return -nd4j::DataTypeUtils::max<X>();
|
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||||
|
@ -3958,7 +3958,7 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline X startingValue(const X *input) {
|
static _CUDA_HD inline X startingValue(const X *input) {
|
||||||
return -nd4j::DataTypeUtils::max<X>();
|
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||||
|
@ -3984,7 +3984,7 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline X startingValue(const X *input) {
|
static _CUDA_HD inline X startingValue(const X *input) {
|
||||||
return nd4j::DataTypeUtils::max<X>();
|
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||||
|
@ -4040,7 +4040,7 @@ namespace simdOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline X startingValue(const X *input) {
|
static _CUDA_HD inline X startingValue(const X *input) {
|
||||||
return nd4j::DataTypeUtils::max<X>();
|
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||||
|
|
|
@ -580,6 +580,152 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_1) {
|
||||||
ASSERT_TRUE(z.equalsTo(zExp));
|
ASSERT_TRUE(z.equalsTo(zExp));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
|
||||||
|
|
||||||
|
NDArray y('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||||
|
NDArray e = NDArrayFactory::create<double>('c', {0, 4});;
|
||||||
|
|
||||||
|
nd4j::ops::multiply op;
|
||||||
|
auto status = op.execute({&x, &y}, {&x}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(e.isSameShape(x));
|
||||||
|
ASSERT_TRUE(e.equalsTo(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
nd4j::ops::maximum op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
nd4j::ops::maximum op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
nd4j::ops::realdiv op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||||
|
|
||||||
|
nd4j::ops::realdiv op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||||
|
|
||||||
|
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2, 1});
|
||||||
|
NDArray y = NDArrayFactory::create<float>('c', {1, 2, 0});
|
||||||
|
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});;
|
||||||
|
|
||||||
|
nd4j::ops::realdiv op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) {
|
||||||
|
|
||||||
|
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
|
||||||
|
NDArray x(nd4j::DataType::DOUBLE, y.getContext(), false);
|
||||||
|
NDArray z(nd4j::DataType::BOOL, y.getContext(), false);
|
||||||
|
NDArray zExp(nd4j::DataType::BOOL, y.getContext(), false);
|
||||||
|
|
||||||
|
nd4j::ops::greater op;
|
||||||
|
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
ASSERT_TRUE(z.isSameShape(zExp));
|
||||||
|
ASSERT_TRUE(z.equalsTo(zExp));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
|
||||||
|
|
||||||
|
NDArray y('c', {1,4}, {1,2,3,4});
|
||||||
|
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||||
|
NDArray e = NDArrayFactory::create<bool>('c', {0, 4});;
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::greater op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {}, {});
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
z->printShapeInfo("z");
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(*z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
||||||
|
|
||||||
NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32);
|
||||||
|
|
|
@ -2021,7 +2021,8 @@ TEST_F(ConvolutionTests1, vol2col_test1) {
|
||||||
// PointersManager manager(columnsExpected.getContext());
|
// PointersManager manager(columnsExpected.getContext());
|
||||||
// manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf());
|
// manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf());
|
||||||
|
|
||||||
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
graph::Context context(1);
|
||||||
|
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
|
||||||
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
||||||
}
|
}
|
||||||
|
@ -2052,7 +2053,8 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
|
||||||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
|
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
|
||||||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.});
|
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.});
|
||||||
|
|
||||||
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
graph::Context context(1);
|
||||||
|
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||||
|
|
||||||
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1302,8 +1302,8 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
|
||||||
TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
|
TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>(10.f);
|
auto input = NDArrayFactory::create<double>(10.f);
|
||||||
auto shape = NDArrayFactory::create<double>(0.f);
|
auto shape = NDArrayFactory::create<Nd4jLong>(1);
|
||||||
auto exp = NDArrayFactory::create<double>(10.f);
|
auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
|
||||||
|
|
||||||
nd4j::ops::broadcast_to op;
|
nd4j::ops::broadcast_to op;
|
||||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
auto results = op.execute({&input, &shape}, {}, {}, {});
|
||||||
|
@ -2261,8 +2261,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
|
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray min('c', {0}, {-63.65f}, nd4j::DataType::FLOAT32);
|
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray max('c', {0}, {0.1f}, nd4j::DataType::FLOAT32);
|
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::fake_quant_with_min_max_vars op;
|
nd4j::ops::fake_quant_with_min_max_vars op;
|
||||||
auto results = op.execute({&x, &min, &max}, {}, {});
|
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||||
|
|
|
@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
|
||||||
|
|
||||||
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
||||||
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
||||||
NDArray dLdwExp('c', {0}, {-227.77286});
|
NDArray dLdwExp('c', {}, {-227.77286});
|
||||||
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
||||||
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
||||||
|
|
||||||
|
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
|
||||||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights(nd4j::DataType::DOUBLE);
|
NDArray weights(nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdwExp('c', {0}, {0.});
|
NDArray dLdwExp('c', {}, {0.});
|
||||||
|
|
||||||
predictions.linspace(0.04, 0.04);
|
predictions.linspace(0.04, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -583,7 +583,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
|
||||||
|
|
||||||
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
|
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
|
||||||
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
|
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
|
||||||
NDArray dLdwExp('c', {0}, {4515.84});
|
NDArray dLdwExp('c', {}, {4515.84});
|
||||||
|
|
||||||
predictions.linspace(0.04, 0.04);
|
predictions.linspace(0.04, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -702,7 +702,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
|
||||||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights(nd4j::DataType::DOUBLE);
|
NDArray weights(nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdwExp('c', {0}, {0.});
|
NDArray dLdwExp('c', {}, {0.});
|
||||||
|
|
||||||
predictions.linspace(0.04, 0.04);
|
predictions.linspace(0.04, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
|
||||||
|
|
||||||
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
|
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
|
||||||
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
|
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
|
||||||
NDArray dLdwExp('c', {0}, {288.});
|
NDArray dLdwExp('c', {}, {288.});
|
||||||
|
|
||||||
predictions.linspace(0.04, 0.04);
|
predictions.linspace(0.04, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
|
||||||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights(nd4j::DataType::DOUBLE);
|
NDArray weights(nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdwExp('c', {0}, {0.});
|
NDArray dLdwExp('c', {}, {0.});
|
||||||
|
|
||||||
predictions.linspace(0.04, 0.04);
|
predictions.linspace(0.04, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
|
||||||
|
|
||||||
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
|
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
|
||||||
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
|
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
|
||||||
NDArray dLdwExp('c', {0}, {-91.52109});
|
NDArray dLdwExp('c', {}, {-91.52109});
|
||||||
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
|
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
|
||||||
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
|
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
|
||||||
|
|
||||||
|
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
|
||||||
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights(nd4j::DataType::DOUBLE);
|
NDArray weights(nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdwExp('c', {0}, {0.});
|
NDArray dLdwExp('c', {}, {0.});
|
||||||
|
|
||||||
logits.linspace(-0.08, 0.04);
|
logits.linspace(-0.08, 0.04);
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
|
@ -2001,10 +2001,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
|
||||||
|
|
||||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
||||||
NDArray dLdwExp('c', {0}, {1.38629});
|
NDArray dLdwExp('c', {}, {1.38629});
|
||||||
|
|
||||||
logits = 2.;
|
logits = 2.;
|
||||||
weights.assign(0.5);
|
weights.assign(0.5);
|
||||||
|
@ -2032,10 +2032,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
|
||||||
|
|
||||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
|
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
|
||||||
NDArray dLdwExp('c', {0}, {0.});
|
NDArray dLdwExp('c', {}, {0.});
|
||||||
|
|
||||||
logits.linspace(-0.08, 0.04);
|
logits.linspace(-0.08, 0.04);
|
||||||
weights = 0.5;
|
weights = 0.5;
|
||||||
|
@ -2466,7 +2466,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
|
||||||
/////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
|
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
|
||||||
|
|
||||||
NDArray labels('c', {0}, {1}, nd4j::DataType::INT64);
|
NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
|
||||||
NDArray logits('c', {2}, {-0.2, 0.3});
|
NDArray logits('c', {2}, {-0.2, 0.3});
|
||||||
|
|
||||||
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
|
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
|
||||||
|
|
|
@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
|
||||||
|
|
||||||
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
|
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
|
||||||
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
|
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
|
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
|
||||||
NDArray dLdwExp('c', {0}, {1.3});
|
NDArray dLdwExp('c', {}, {1.3});
|
||||||
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
|
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
|
||||||
|
|
||||||
predictions.linspace(-0.4, 0.2);
|
predictions.linspace(-0.4, 0.2);
|
||||||
|
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
|
||||||
TEST_F(DeclarableOpsTests12, hinge_loss_14) {
|
TEST_F(DeclarableOpsTests12, hinge_loss_14) {
|
||||||
|
|
||||||
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
|
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
|
||||||
NDArray weights('c', {0}, {1.});
|
NDArray weights('c', {}, {1.});
|
||||||
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
|
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
|
||||||
|
|
||||||
NDArray output('c', {0}, nd4j::DataType::DOUBLE);
|
NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
logits.linspace(1.);
|
logits.linspace(1.);
|
||||||
weights.assign(1.);
|
weights.assign(1.);
|
||||||
|
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
|
||||||
TEST_F(DeclarableOpsTests12, reverse_test15) {
|
TEST_F(DeclarableOpsTests12, reverse_test15) {
|
||||||
|
|
||||||
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray axis('c', {0}, {0}, nd4j::DataType::INT32);
|
NDArray axis('c', {}, {0}, nd4j::DataType::INT32);
|
||||||
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);
|
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
|
|
@ -123,6 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Reduced shape");
|
||||||
ASSERT_EQ(e, *z);
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
@ -213,3 +214,199 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
||||||
|
auto x = NDArrayFactory::empty<int>();
|
||||||
|
auto y = NDArrayFactory::create<int>(1);
|
||||||
|
|
||||||
|
nd4j::ops::fill op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(y, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
||||||
|
auto a = NDArrayFactory::create<float>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
|
||||||
|
auto b = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto c = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto d = NDArrayFactory::create<float>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {3});
|
||||||
|
auto f = NDArrayFactory::create<float>('c', {3});
|
||||||
|
auto g = NDArrayFactory::create<float>('c', {3});
|
||||||
|
auto h = NDArrayFactory::create<float>('c', {12});
|
||||||
|
|
||||||
|
auto z0 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z1 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z2 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z3 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z4 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z5 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
auto z6 = NDArrayFactory::create<float>('c', {1, 3});
|
||||||
|
|
||||||
|
nd4j::ops::lstmBlockCell op;
|
||||||
|
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
|
||||||
|
nd4j::ops::stack op;
|
||||||
|
auto result = op.execute({&x}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
nd4j::ops::reduce_min sumOp;
|
||||||
|
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||||
|
ASSERT_EQ(res2->status(), Status::OK());
|
||||||
|
auto out = res2->at(0);
|
||||||
|
out->printShapeInfo("ReduceSum empty shape with keep dims");
|
||||||
|
out->printIndexedBuffer("ReduceSum scalar");
|
||||||
|
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||||
|
delete res2;
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
|
||||||
|
auto x = NDArrayFactory::empty<float>();
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {0});
|
||||||
|
|
||||||
|
nd4j::ops::stack op;
|
||||||
|
auto result = op.execute({&x}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
|
||||||
|
auto x = NDArrayFactory::empty<float>();
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||||
|
|
||||||
|
nd4j::ops::stack op;
|
||||||
|
auto result = op.execute({&x, &x}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||||
|
|
||||||
|
nd4j::ops::stack op;
|
||||||
|
auto result = op.execute({&x, &x}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
nd4j::ops::reduce_min sumOp;
|
||||||
|
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||||
|
ASSERT_EQ(res2->status(), Status::OK());
|
||||||
|
auto out = res2->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||||
|
delete res2;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
nd4j::ops::reduce_max sumOp;
|
||||||
|
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||||
|
ASSERT_EQ(res2->status(), Status::OK());
|
||||||
|
auto out = res2->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
|
||||||
|
delete res2;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
nd4j::ops::reduce_sum sumOp;
|
||||||
|
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||||
|
ASSERT_EQ(res2->status(), Status::OK());
|
||||||
|
auto out = res2->at(0);
|
||||||
|
ASSERT_EQ(out->e<float>(0), 0.f);
|
||||||
|
delete res2;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
nd4j::ops::reduce_mean sumOp;
|
||||||
|
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||||
|
ASSERT_EQ(res2->status(), Status::OK());
|
||||||
|
auto out = res2->at(0);
|
||||||
|
out->printShapeInfo("ReduceMean empty shape with keep dims");
|
||||||
|
out->printIndexedBuffer("ReduceMean scalar");
|
||||||
|
ASSERT_EQ(out->e<float>(0), 0.f);
|
||||||
|
delete res2;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
auto y = NDArrayFactory::create<int>(0);
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {0});
|
||||||
|
|
||||||
|
nd4j::ops::argmax op;
|
||||||
|
//nd4j::ops::reduce_max op;
|
||||||
|
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
z->printShapeInfo("Z");
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
auto y = NDArrayFactory::create<int>(1);
|
||||||
|
|
||||||
|
nd4j::ops::argmax op;
|
||||||
|
try {
|
||||||
|
auto result = op.execute({&x, &y}, {&y}, {}, {}, {});
|
||||||
|
ASSERT_TRUE(false);
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
//
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {32, 0});
|
||||||
|
|
||||||
|
nd4j::ops::tanh op;
|
||||||
|
auto result = op.execute({&x}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(x.isSameShape(z));
|
||||||
|
ASSERT_EQ(x, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
|
@ -905,9 +905,9 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {1}, {10});
|
auto x = NDArrayFactory::create<double>('c', {1}, {10});
|
||||||
auto begin = NDArrayFactory::create<double>('c', {1}, {0.});
|
auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||||
auto end = NDArrayFactory::create<double>('c', {1}, {0.});
|
auto end = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||||
auto stride = NDArrayFactory::create<double>('c', {1}, {1});
|
auto stride = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||||||
//exp.linspace(1);
|
//exp.linspace(1);
|
||||||
|
|
|
@ -2421,6 +2421,26 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests5, log_softmax_test12) {
|
||||||
|
|
||||||
|
auto input = NDArrayFactory::create<double>('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864});
|
||||||
|
auto expOutput = NDArrayFactory::create<double>('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472});
|
||||||
|
|
||||||
|
for (int i = 0; i < 10; ++i)
|
||||||
|
{
|
||||||
|
nd4j::ops::log_softmax op;
|
||||||
|
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||||
|
auto z = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(z, 1e-4));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {
|
TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {
|
||||||
|
|
||||||
|
|
|
@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
||||||
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
|
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
|
||||||
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
|
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
|
||||||
|
|
||||||
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
|
auto exp = NDArrayFactory::create<double>(10);
|
||||||
|
|
||||||
//matrix.linspace(1);
|
//matrix.linspace(1);
|
||||||
|
|
||||||
|
@ -119,7 +119,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
||||||
|
|
||||||
auto z = result->at(0);
|
auto z = result->at(0);
|
||||||
z->printShapeInfo("SS OS shape");
|
z->printShapeInfo("SS OS shape");
|
||||||
ASSERT_TRUE(z->isEmpty());
|
z->printIndexedBuffer("SS OS out");
|
||||||
|
ASSERT_TRUE(z->equalsTo(exp));
|
||||||
//ASSERT_EQ(exp, *z);
|
//ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
|
@ -608,6 +608,24 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<double>('c', {0,2,3});
|
||||||
|
auto y = NDArrayFactory::create<double>('c', {0,2,3});
|
||||||
|
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
|
||||||
|
|
||||||
|
nd4j::ops::concat op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {0});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
||||||
|
|
||||||
|
|
|
@ -59,7 +59,8 @@ TEST_F(EmptyTests, Test_Create_Empty_2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Concat_1) {
|
TEST_F(EmptyTests, Test_Concat_1) {
|
||||||
auto empty = NDArrayFactory::empty_<float>();
|
// auto empty = NDArrayFactory::empty_<float>();
|
||||||
|
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32);//NDArrayFactory::create_<float>('c', {(Nd4jLong)0}};
|
||||||
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||||
|
|
||||||
ASSERT_TRUE(empty->isEmpty());
|
ASSERT_TRUE(empty->isEmpty());
|
||||||
|
@ -82,9 +83,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
||||||
|
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Concat_2) {
|
TEST_F(EmptyTests, Test_Concat_2) {
|
||||||
auto empty = NDArrayFactory::empty_<float>();
|
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32); //NDArrayFactory::empty_<float>();
|
||||||
auto scalar1 = NDArrayFactory::create_<float>(1.0f);
|
auto scalar1 = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||||
auto scalar2 = NDArrayFactory::create_<float>(2.0f);
|
auto scalar2 = NDArrayFactory::create_<float>('c', {1}, {2.0f});
|
||||||
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
ASSERT_TRUE(empty->isEmpty());
|
ASSERT_TRUE(empty->isEmpty());
|
||||||
|
@ -139,6 +140,23 @@ TEST_F(EmptyTests, Test_Reshape_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
||||||
|
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {10, 0});
|
||||||
|
|
||||||
|
nd4j::ops::reshape op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_dup_1) {
|
TEST_F(EmptyTests, Test_dup_1) {
|
||||||
auto empty = NDArrayFactory::empty<int>();
|
auto empty = NDArrayFactory::empty<int>();
|
||||||
auto dup = empty.dup();
|
auto dup = empty.dup();
|
||||||
|
@ -148,3 +166,47 @@ TEST_F(EmptyTests, Test_dup_1) {
|
||||||
|
|
||||||
delete dup;
|
delete dup;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_1) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
std::vector<Nd4jLong> shape = {2, 0, 3};
|
||||||
|
|
||||||
|
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(3, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_2) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {0, 3});
|
||||||
|
std::vector<Nd4jLong> shape = {0, 3};
|
||||||
|
|
||||||
|
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(2, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_3) {
|
||||||
|
auto empty = NDArrayFactory::create<float>('c', {0});
|
||||||
|
std::vector<Nd4jLong> shape = {0};
|
||||||
|
|
||||||
|
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||||
|
ASSERT_EQ(0, empty.lengthOf());
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||||
|
ASSERT_EQ(1, empty.rankOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_shaped_empty_4) {
|
||||||
|
auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::FLOAT32);
|
||||||
|
shape::printShapeInfoLinear("shape", shape);
|
||||||
|
NDArray array(shape, true, nd4j::LaunchContext::defaultContext());
|
||||||
|
std::vector<Nd4jLong> shapeOf({0});
|
||||||
|
|
||||||
|
ASSERT_TRUE(array.isEmpty());
|
||||||
|
ASSERT_EQ(1, array.rankOf());
|
||||||
|
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||||
|
}
|
|
@ -668,3 +668,47 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
|
||||||
delete row;
|
delete row;
|
||||||
delete erow;
|
delete erow;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
e.assign(std::numeric_limits<float>::infinity());
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||||
|
e.assign(-std::numeric_limits<float>::infinity());
|
||||||
|
|
||||||
|
int dim = 1;
|
||||||
|
|
||||||
|
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(LegacyOpsTests, test_legacy_transform_float_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||||
|
|
||||||
|
NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
|
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue