From 68ea5f3688be9045cac3c06b1828b21262b22903 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 15 Jun 2019 21:34:34 +1000 Subject: [PATCH] Dev branch merge: dev_20190606 (#7904) * correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. * couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. * broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaks --- .../embedding/EmbeddingLayerTest.java | 228 +- .../nn/layers/samediff/TestSameDiffConv.java | 10 + .../nn/layers/samediff/TestSameDiffDense.java | 22 + .../samediff/TestSameDiffDenseVertex.java | 6 + .../layers/samediff/TestSameDiffLambda.java | 13 + .../layers/samediff/TestSameDiffOutput.java | 12 + .../loader/WordVectorSerializer.java | 37 + .../reader/impl/BasicModelUtils.java | 20 +- .../reader/impl/FlatModelUtils.java | 3 + .../reader/impl/TreeModelUtils.java | 1 + .../embeddings/wordvectors/WordVectors.java | 6 + .../wordvectors/WordVectorsImpl.java | 6 + .../models/fasttext/FastText.java | 353 +- .../models/word2vec/StaticWord2Vec.java | 5 + .../models/fasttext/FastTextTest.java | 96 +- .../WordVectorSerializerTest.java | 31 + .../nn/graph/ComputationGraph.java | 1 + .../embedding/EmbeddingSequenceLayer.java | 3 +- .../nn/multilayer/MultiLayerNetwork.java | 2 + .../ui/play/TestSameDiffUI.java | 2 +- libnd4j/blas/NDArray.h | 3 - libnd4j/blas/NDArray.hpp | 61 +- libnd4j/blas/NativeOps.h | 13 +- libnd4j/blas/cpu/NDArray.cpp | 326 +- libnd4j/blas/cpu/NDArrayLambda.hpp | 325 ++ libnd4j/blas/cpu/NativeOps.cpp | 26 + libnd4j/blas/cuda/NDArray.cu | 8 +- libnd4j/blas/cuda/NativeOps.cu | 26 + libnd4j/include/array/DataTypeUtils.h | 39 + .../include/array/impl/ShapeDescriptor.cpp | 40 +- libnd4j/include/helpers/ShapeBuilders.h | 4 +- libnd4j/include/helpers/ShapeUtils.h | 6 + .../include/helpers/cpu/ConstantTadHelper.cpp | 4 +- .../include/helpers/cuda/ConstantTadHelper.cu | 3 +- .../include/helpers/impl/ShapeBuilders.cpp | 35 +- libnd4j/include/helpers/impl/ShapeUtils.cpp | 99 +- libnd4j/include/helpers/shape.h | 4 +- libnd4j/include/loops/cpu/indexreduce.cpp | 18 +- .../include/loops/cpu/reduce/reduce_bool.cpp | 43 +- .../include/loops/cpu/reduce/reduce_float.cpp | 33 +- .../include/loops/cpu/reduce/reduce_long.cpp | 43 +- .../include/loops/cpu/reduce/reduce_same.cpp | 40 +- libnd4j/include/loops/cpu/reduce3.cpp | 22 +- .../include/loops/cpu/summarystatsreduce.cpp | 48 +- .../generic/broadcastable/realdiv.cpp | 2 +- .../ops/declarable/generic/convo/conv1d.cpp | 4 +- .../ops/declarable/generic/convo/conv2d.cpp | 6 +- .../ops/declarable/generic/convo/conv3d.cpp | 6 +- .../declarable/generic/convo/deconv2d_tf.cpp | 2 +- .../ops/declarable/generic/convo/deconv3d.cpp | 4 +- .../generic/convo/depthwiseConv2d.cpp | 4 +- .../generic/convo/pointwiseConv2d.cpp | 2 +- .../generic/convo/pooling/avgpool2d.cpp | 4 +- .../generic/convo/pooling/avgpool3d.cpp | 4 +- .../generic/convo/pooling/maxpool2d.cpp | 4 +- .../generic/convo/pooling/maxpool3d.cpp | 4 +- .../generic/convo/pooling/pnormpool2d.cpp | 4 +- .../ops/declarable/generic/convo/sconv2d.cpp | 10 +- .../declarable/generic/convo/upsampling2d.cpp | 4 +- .../declarable/generic/convo/upsampling3d.cpp | 4 +- .../ops/declarable/generic/nn/batchnorm.cpp | 2 +- .../declarable/generic/parity_ops/argmax.cpp | 5 + .../declarable/generic/parity_ops/argmin.cpp | 4 + .../declarable/generic/parity_ops/fill.cpp | 6 - .../declarable/generic/parity_ops/range.cpp | 25 +- .../declarable/generic/parity_ops/rank.cpp | 3 +- .../generic/parity_ops/reduce_max.cpp | 2 +- .../declarable/generic/parity_ops/slice.cpp | 8 +- .../declarable/generic/parity_ops/stack.cpp | 38 +- .../generic/parity_ops/strided_slice.cpp | 40 +- .../declarable/generic/parity_ops/unstack.cpp | 20 +- .../generic/parity_ops/zero_fraction.cpp | 6 + .../generic/recurrent/lstmBlock.cpp | 10 +- .../generic/recurrent/lstmBlockCell.cpp | 10 +- .../declarable/generic/shape/broadcast_to.cpp | 1 - .../ops/declarable/generic/shape/permute.cpp | 65 +- .../ops/declarable/generic/shape/reshape.cpp | 55 +- .../declarable/generic/transforms/concat.cpp | 108 +- .../declarable/generic/transforms/cumprod.cpp | 5 + .../declarable/generic/transforms/cumsum.cpp | 5 + .../declarable/generic/transforms/gather.cpp | 11 - .../declarable/generic/transforms/reverse.cpp | 5 + .../ops/declarable/helpers/convolutions.h | 48 +- .../declarable/helpers/cpu/activations.cpp | 12 +- .../declarable/helpers/cpu/convolutions.cpp | 2974 ++++++++--------- .../ops/declarable/helpers/cpu/lstm.cpp | 14 +- .../declarable/helpers/cpu/max_pooling.cpp | 2 +- .../ops/declarable/impl/BroadcastableOp.cpp | 13 +- libnd4j/include/ops/ops.h | 14 +- .../layers_tests/BroadcastableOpsTests.cpp | 146 + .../layers_tests/ConvolutionTests1.cpp | 6 +- .../layers_tests/DeclarableOpsTests10.cpp | 8 +- .../layers_tests/DeclarableOpsTests11.cpp | 26 +- .../layers_tests/DeclarableOpsTests12.cpp | 10 +- .../layers_tests/DeclarableOpsTests14.cpp | 199 +- .../layers_tests/DeclarableOpsTests4.cpp | 6 +- .../layers_tests/DeclarableOpsTests5.cpp | 292 +- .../layers_tests/DeclarableOpsTests6.cpp | 5 +- .../layers_tests/DeclarableOpsTests9.cpp | 318 +- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 70 +- .../tests_cpu/layers_tests/LegacyOpsTests.cpp | 44 + .../layers_tests/MultiDataTypeTests.cpp | 32 +- .../tests_cpu/layers_tests/NDArrayTests2.cpp | 18 + .../tests_cpu/layers_tests/ParityOpsTests.cpp | 51 +- libnd4j/tests_cpu/layers_tests/TadTests.cpp | 1 + .../DifferentialFunctionFactory.java | 16 +- .../nd4j/autodiff/listeners/BaseListener.java | 2 +- .../org/nd4j/autodiff/listeners/Listener.java | 2 +- .../autodiff/listeners/impl/UIListener.java | 4 +- .../nd4j/autodiff/samediff/SDVariable.java | 12 + .../org/nd4j/autodiff/samediff/SameDiff.java | 269 +- .../samediff/internal/InferenceSession.java | 52 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 16 + .../nd4j/autodiff/samediff/ops/SDMath.java | 21 +- .../autodiff/validation/GradCheckUtil.java | 4 +- .../org/nd4j/evaluation/BaseEvaluation.java | 32 +- .../evaluation/classification/Evaluation.java | 32 +- .../classification/EvaluationBinary.java | 81 +- .../classification/EvaluationCalibration.java | 4 +- .../nd4j/evaluation/classification/ROC.java | 34 +- .../evaluation/classification/ROCBinary.java | 18 + .../classification/ROCMultiClass.java | 18 + .../regression/RegressionEvaluation.java | 45 +- .../imports/graphmapper/tf/TFGraphMapper.java | 419 +-- .../tf/tensors/TFTensorMapper.java | 41 + .../tf/tensors/TFTensorMappers.java | 726 ++++ .../nd4j/linalg/api/ndarray/BaseNDArray.java | 28 +- .../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 2 +- .../linalg/api/ops/BaseReduceFloatOp.java | 2 +- .../nd4j/linalg/api/ops/BaseReduceLongOp.java | 2 +- .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 2 + .../nd4j/linalg/api/ops/BaseReduceSameOp.java | 2 +- .../api/ops/impl/broadcast/BroadcastTo.java | 12 + .../api/ops/impl/reduce/custom/LogSumExp.java | 41 +- .../linalg/api/ops/impl/shape/Gather.java | 49 +- .../linalg/api/ops/impl/shape/Permute.java | 16 +- .../linalg/api/ops/impl/shape/Transpose.java | 31 +- .../api/ops/impl/transforms/custom/Fill.java | 65 +- .../ops/impl/transforms/custom/ListDiff.java | 5 + .../linalg/api/ops/random/impl/Range.java | 24 +- .../linalg/api/shape/LongShapeDescriptor.java | 4 + .../java/org/nd4j/linalg/api/shape/Shape.java | 23 +- .../linalg/factory/BaseNDArrayFactory.java | 1 + .../java/org/nd4j/linalg/factory/Nd4j.java | 35 +- .../deallocation/DeallocatorService.java | 4 + .../nativeblas/BaseNativeNDArrayFactory.java | 145 +- .../nativecpu/ops/NativeOpExecutioner.java | 38 +- .../opvalidation/ShapeOpValidation.java | 175 +- .../opvalidation/TransformOpValidation.java | 15 + .../samediff/FlatBufferSerdeTest.java | 30 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 42 + .../org/nd4j/autodiff/ui/UIListenerTest.java | 13 +- .../nd4j/evaluation/EmptyEvaluationTests.java | 138 + .../nd4j/evaluation/EvaluationBinaryTest.java | 201 ++ .../evaluation/EvaluationCalibrationTest.java | 66 + .../org/nd4j/evaluation/ROCBinaryTest.java | 203 ++ .../nd4j/evaluation/RegressionEvalTest.java | 190 +- .../TFGraphs/TFGraphTestAllHelper.java | 63 +- .../imports/listener/ImportDebugListener.java | 2 +- .../org/nd4j/linalg/NDArrayTestsFortran.java | 6 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 114 +- .../linalg/api/indexing/IndexingTests.java | 2 +- .../linalg/api/indexing/IndexingTestsC.java | 4 +- .../linalg/broadcast/BasicBroadcastTests.java | 30 + .../nd4j/linalg/serde/NumpyFormatTests.java | 45 + .../org/nd4j/linalg/shape/EmptyTests.java | 140 + .../linalg/api/buffer/BaseDataBuffer.java | 11 +- .../org/nd4j/linalg/api/buffer/DataType.java | 4 + .../java/org/nd4j/linalg/util/ArrayUtil.java | 30 + 169 files changed, 7207 insertions(+), 3633 deletions(-) create mode 100644 libnd4j/blas/cpu/NDArrayLambda.hpp create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index ef6a2ef4b..d8921346d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer; @@ -283,7 +284,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - System.out.println(net.score() + "\t" + net2.score()); assertEquals(net2.score(), net.score(), 1e-6); Map gradient = net.gradient().gradientForVariable(); @@ -441,85 +441,87 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int numInputClasses = 10; int timeSeriesLength = 5; - for (int nExamples : miniBatchSizes) { - Nd4j.getRandom().setSeed(12345); + for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { + for (int nExamples : miniBatchSizes) { + Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) - .nOut(5).build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Sgd(0.1)).seed(12345).list() + .layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) + .nOut(5).build()) + .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) + .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) + .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) + .nOut(4).build()) + .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) + .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.1)).seed(12345).list() - .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) - .build()) - .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) - .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) - .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) - .nOut(4).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Sgd(0.1)).seed(12345).list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) + .build()) + .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) + .layer(2, new GravesLSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) + .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) + .nOut(4).build()) + .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) + .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()).build(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); - net2.init(); + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.params().dup()); - INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); - INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); + INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); + INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); - INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength); + INDArray labels = Nd4j.zeros(nExamples, 4, timeSeriesLength); - for (int i = 0; i < nExamples; i++) { - for (int j = 0; j < timeSeriesLength; j++) { - int inIdx = r.nextInt(numInputClasses); - inEmbedding.putScalar(new int[]{i, 0, j}, inIdx); - inDense.putScalar(new int[]{i, inIdx, j}, 1.0); + for (int i = 0; i < nExamples; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + int inIdx = r.nextInt(numInputClasses); + inEmbedding.putScalar(new int[]{i, 0, j}, inIdx); + inDense.putScalar(new int[]{i, inIdx, j}, 1.0); - int outIdx = r.nextInt(4); - labels.putScalar(new int[]{i, outIdx, j}, 1.0); + int outIdx = r.nextInt(4); + labels.putScalar(new int[]{i, outIdx, j}, 1.0); + } } - } - INDArray inputMask = Nd4j.zeros(nExamples, timeSeriesLength); - for (int i = 0; i < nExamples; i++) { - for (int j = 0; j < timeSeriesLength; j++) { - inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); + INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength); + for (int i = 0; i < nExamples; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); + } } - } - net.setLayerMaskArrays(inputMask, null); - net2.setLayerMaskArrays(inputMask, null); - List actEmbedding = net.feedForward(inEmbedding, false); - List actDense = net2.feedForward(inDense, false); - for (int i = 1; i < actEmbedding.size(); i++) { - assertEquals(actDense.get(i), actEmbedding.get(i)); - } + net.setLayerMaskArrays(inputMask, null); + net2.setLayerMaskArrays(inputMask, null); + List actEmbedding = net.feedForward(inEmbedding, false); + List actDense = net2.feedForward(inDense, false); + for (int i = 1; i < actEmbedding.size(); i++) { + assertEquals(actDense.get(i), actEmbedding.get(i)); + } - net.setLabels(labels); - net2.setLabels(labels); - net.computeGradientAndScore(); - net2.computeGradientAndScore(); + net.setLabels(labels); + net2.setLabels(labels); + net.computeGradientAndScore(); + net2.computeGradientAndScore(); - System.out.println(net.score() + "\t" + net2.score()); - assertEquals(net2.score(), net.score(), 1e-5); + System.out.println(net.score() + "\t" + net2.score()); + assertEquals(net2.score(), net.score(), 1e-5); - Map gradients = net.gradient().gradientForVariable(); - Map gradients2 = net2.gradient().gradientForVariable(); - assertEquals(gradients.keySet(), gradients2.keySet()); - for (String s : gradients.keySet()) { - assertEquals(gradients2.get(s), gradients.get(s)); + Map gradients = net.gradient().gradientForVariable(); + Map gradients2 = net2.gradient().gradientForVariable(); + assertEquals(gradients.keySet(), gradients2.keySet()); + for (String s : gradients.keySet()) { + assertEquals(gradients2.get(s), gradients.get(s)); + } } } } @@ -583,6 +585,104 @@ public class EmbeddingLayerTest extends BaseDL4JTest { } } + @Test + public void testEmbeddingSequenceLayerWithMasking() { + //Idea: have masking on the input with an embedding and dense layers on input + //Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked + + int[] miniBatchSizes = {1, 3}; + int nIn = 2; + Random r = new Random(12345); + + int numInputClasses = 10; + int timeSeriesLength = 5; + + for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { + for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) { + for(int inputRank : new int[]{2, 3}) { + for (int nExamples : miniBatchSizes) { + Nd4j.getRandom().setSeed(12345); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Sgd(0.1)).seed(12345).list() + .layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) + .nOut(5).build()) + .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) + .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) + .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) + .nOut(4).build()) + .setInputType(InputType.recurrent(1)).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .updater(new Sgd(0.1)).seed(12345).list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) + .build()) + .layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build()) + .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) + .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) + .nOut(4).build()) + .setInputType(InputType.recurrent(1)).build(); + + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + net2.setParams(net.params().dup()); + + INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength}); + INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength); + + INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength); + + for (int i = 0; i < nExamples; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + int inIdx = r.nextInt(numInputClasses); + inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx); + inDense.putScalar(new int[]{i, inIdx, j}, 1.0); + + int outIdx = r.nextInt(4); + labels.putScalar(new int[]{i, outIdx, j}, 1.0); + } + } + + INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength); + for (int i = 0; i < nExamples; i++) { + for (int j = 0; j < timeSeriesLength; j++) { + inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0)); + } + } + + net.setLayerMaskArrays(inputMask, null); + net2.setLayerMaskArrays(inputMask, null); + List actEmbedding = net.feedForward(inEmbedding, false); + List actDense = net2.feedForward(inDense, false); + for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape) + assertEquals(actDense.get(i), actEmbedding.get(i)); + } + + net.setLabels(labels); + net2.setLabels(labels); + net.computeGradientAndScore(); + net2.computeGradientAndScore(); + + assertEquals(net2.score(), net.score(), 1e-5); + + Map gradients = net.gradient().gradientForVariable(); + Map gradients2 = net2.gradient().gradientForVariable(); + assertEquals(gradients.keySet(), gradients2.keySet()); + for (String s : gradients.keySet()) { + assertEquals(gradients2.get(s), gradients.get(s)); + } + } + } + } + } + } + @EqualsAndHashCode private static class WordVectorsMockup implements EmbeddingInitializer { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index 2925dccff..d45195870 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -213,6 +213,12 @@ public class TestSameDiffConv extends BaseDL4JTest { INDArray outLoaded = netLoaded.output(in); assertEquals(msg, outExp, outLoaded); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = net.output(newIn); + INDArray outMb = net2.output(newIn); + assertEquals(outMb, outMbsd); } } } @@ -306,6 +312,10 @@ public class TestSameDiffConv extends BaseDL4JTest { assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(f, f); + net.output(newIn); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index a0adf36fd..df6757608 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -137,6 +137,12 @@ public class TestSameDiffDense extends BaseDL4JTest { INDArray outLoaded = netLoaded.output(in); assertEquals(outExp, outLoaded); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = net.output(newIn); + INDArray outMb = net2.output(newIn); + assertEquals(outMb, outMbsd); } } } @@ -314,6 +320,12 @@ public class TestSameDiffDense extends BaseDL4JTest { netSD.computeGradientAndScore(); // netStandard.computeGradientAndScore(); // assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient()); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = netSD.output(newIn); + INDArray outMb = netStandard.output(newIn); + assertEquals(outMb, outMbsd); } } } @@ -377,6 +389,12 @@ public class TestSameDiffDense extends BaseDL4JTest { assertEquals(s, netStandard.params(), netSD.params()); assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray()); } + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures()); + INDArray outMbsd = netSD.output(newIn); + INDArray outMb = netStandard.output(newIn); + assertEquals(outMb, outMbsd); } @Test @@ -417,6 +435,10 @@ public class TestSameDiffDense extends BaseDL4JTest { assertTrue(msg, gradOK); TestUtils.testModelSerialization(net); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(f, f); + net.output(newIn); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 4d7fca598..7f9a54f8e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -166,6 +166,12 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { outSD = loaded.outputSingle(in); outStd = netStandard.outputSingle(in); assertEquals(outStd, outSD); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = netSD.output(newIn)[0]; + INDArray outMb = netStandard.output(newIn)[0]; + assertEquals(outMb, outMbsd); } } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index c96cf0ad8..6264aaf72 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -115,6 +115,12 @@ public class TestSameDiffLambda extends BaseDL4JTest { outStd = std.outputSingle(in); assertEquals(outStd, outLambda); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = lambda.output(newIn)[0]; + INDArray outMb = std.output(newIn)[0]; + assertEquals(outMb, outMbsd); } @Test @@ -186,5 +192,12 @@ public class TestSameDiffLambda extends BaseDL4JTest { outStd = std.output(in1, in2)[0]; assertEquals(outStd, outLambda); + + //Sanity check on different minibatch sizes: + INDArray newIn1 = Nd4j.vstack(in1, in1); + INDArray newIn2 = Nd4j.vstack(in2, in2); + INDArray outMbsd = lambda.output(newIn1, newIn2)[0]; + INDArray outMb = std.output(newIn1, newIn2)[0]; + assertEquals(outMb, outMbsd); } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java index b0089c2f7..53e6d0ed2 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java @@ -90,6 +90,12 @@ public class TestSameDiffOutput extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone()); net.init(); net.fit(ds); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = netSD.output(newIn); + INDArray outMb = netStd.output(newIn); + assertEquals(outMb, outMbsd); } @@ -164,6 +170,12 @@ public class TestSameDiffOutput extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone()); net.init(); net.fit(ds); + + //Sanity check on different minibatch sizes: + INDArray newIn = Nd4j.vstack(in, in); + INDArray outMbsd = netSD.output(newIn); + INDArray outMb = netStd.output(newIn); + assertEquals(outMb, outMbsd); } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index 78de12332..fdecb5ea4 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -32,6 +32,7 @@ import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl; +import org.deeplearning4j.models.fasttext.FastText; import org.deeplearning4j.models.glove.Glove; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors; @@ -3089,6 +3090,42 @@ public class WordVectorSerializer { word2Vec.setModelUtils(vectors.getModelUtils()); return word2Vec; } + + public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException { + ObjectOutputStream outputStream = null; + try { + outputStream = new ObjectOutputStream(new FileOutputStream(path )); + outputStream.writeObject(vectors); + } + finally { + try { + if (outputStream != null) { + outputStream.flush(); + outputStream.close(); + } + } catch (IOException ex) { + ex.printStackTrace(); + } + } + } + + public static FastText readWordVectors(File path) { + FastText result = null; + try { + FileInputStream fileIn = new FileInputStream(path); + ObjectInputStream in = new ObjectInputStream(fileIn); + try { + result = (FastText) in.readObject(); + } catch (ClassNotFoundException ex) { + + } + } catch (FileNotFoundException ex) { + ex.printStackTrace(); + } catch (IOException ex) { + ex.printStackTrace(); + } + return result; + } public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) { double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index 68ed56073..240a33bb1 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -21,6 +21,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.reader.ModelUtils; @@ -207,7 +208,6 @@ public class BasicModelUtils implements ModelUtils } INDArray mean = words.isMatrix() ? words.mean(0) : words; - Collection tempRes = wordsNearest(mean, top + positive.size() + negative.size()); List realResults = new ArrayList<>(); @@ -232,6 +232,22 @@ public class BasicModelUtils implements ModelUtils return wordsNearestSum(vec, n); } + protected INDArray adjustRank(INDArray words) { + if (lookupTable instanceof InMemoryLookupTable) { + InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; + + INDArray syn0 = l.getSyn0(); + if (!words.dataType().equals(syn0.dataType())) { + return words.castTo(syn0.dataType()); + } + if (words.rank() == 0 || words.rank() > 2) { + throw new IllegalStateException("Invalid rank for wordsNearest method"); + } else if (words.rank() == 1) { + return words.reshape(1, -1); + } + } + return words; + } /** * Words nearest based on positive and negative words * * @param top the top n words @@ -239,6 +255,8 @@ public class BasicModelUtils implements ModelUtils */ @Override public Collection wordsNearest(INDArray words, int top) { + words = adjustRank(words); + if (lookupTable instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java index ef2414bd5..bed701d73 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java @@ -16,6 +16,7 @@ package org.deeplearning4j.models.embeddings.reader.impl; +import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.ops.transforms.Transforms; @@ -64,6 +65,8 @@ public class FlatModelUtils extends BasicModelUtils wordsNearest(INDArray words, int top) { Counter distances = new Counter<>(); + words = adjustRank(words); + for (String s : vocabCache.words()) { INDArray otherVec = lookupTable.vector(s); double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup())); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java index b82a8b5af..8b0be242d 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java @@ -103,6 +103,7 @@ public class TreeModelUtils extends BasicModelUtils wordsNearest(INDArray words, int top) { checkTree(); + words = adjustRank(words); List add = new ArrayList<>(); List distances = new ArrayList<>(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectors.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectors.java index f766bafc5..a9f8c4fd4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectors.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectors.java @@ -172,4 +172,10 @@ public interface WordVectors extends Serializable, EmbeddingInitializer { */ void setModelUtils(ModelUtils utils); + /** + * Does implementation vectorize words absent in vocabulary + * @return boolean + */ + boolean outOfVocabularySupported(); + } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java index facef70d8..75511cae1 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.AtomicDouble; import lombok.Getter; import lombok.NonNull; import lombok.Setter; +import lombok.val; import org.apache.commons.lang.ArrayUtils; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; @@ -357,4 +358,9 @@ public class WordVectorsImpl implements WordVectors { public boolean jsonSerializable() { return false; } + + @Override + public boolean outOfVocabularySupported() { + return false; + } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java index fd2ed63c7..a704e7f10 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java @@ -6,10 +6,13 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.StringUtils; import org.deeplearning4j.models.embeddings.WeightLookupTable; +import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; +import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.reader.ModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.word2vec.VocabWord; +import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; @@ -17,41 +20,78 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.Collection; -import java.util.List; -import java.util.Map; +import java.io.*; +import java.util.*; @Slf4j @AllArgsConstructor @lombok.Builder -public class FastText implements WordVectors { +public class FastText implements WordVectors, Serializable { - private boolean supervised; - private boolean quantize; - private boolean predict; - private boolean predict_prob; + // Mandatory + @Getter private String inputFile; + @Getter private String outputFile; - private boolean skipgram; - @Builder.Default private int bucket = 100; - @Builder.Default private int minCount = 1; + // Optional for dictionary + @Builder.Default private int bucket = -1; + @Builder.Default private int minCount = -1; + @Builder.Default private int minCountLabel = -1; + @Builder.Default private int wordNgrams = -1; + @Builder.Default private int minNgramLength = -1; + @Builder.Default private int maxNgramLength = -1; + @Builder.Default private int samplingThreshold = -1; + private String labelPrefix; - private boolean cbow; - private boolean nn; - private boolean analogies; - private String inputFile; - private String outputFile; - private SentenceIterator iterator; - private String modelName; - private String lossName; - //TODO: - private double[] pretrainedVectors; + // Optional for training + @Getter private boolean supervised; + @Getter private boolean quantize; + @Getter private boolean predict; + @Getter private boolean predict_prob; + @Getter private boolean skipgram; + @Getter private boolean cbow; + @Getter private boolean nn; + @Getter private boolean analogies; + @Getter private String pretrainedVectorsFile; + @Getter + @Builder.Default + private double learningRate = -1.0; + @Getter private double learningRateUpdate = -1.0; + @Getter + @Builder.Default + private int dim = -1; + @Getter + @Builder.Default + private int contextWindowSize = -1; + @Getter + @Builder.Default + private int epochs = -1; + @Getter private String modelName; + @Getter private String lossName; + @Getter + @Builder.Default + private int negativeSamples = -1; + @Getter + @Builder.Default + private int numThreads = -1; + @Getter private boolean saveOutput = false; - private JFastText fastTextImpl; - private boolean modelLoaded; + // Optional for quantization + @Getter + @Builder.Default + private int cutOff = -1; + @Getter private boolean retrain; + @Getter private boolean qnorm; + @Getter private boolean qout; + @Getter + @Builder.Default + private int dsub = -1; + + @Getter private SentenceIterator iterator; + + @Builder.Default private transient JFastText fastTextImpl = new JFastText(); + private transient Word2Vec word2Vec; + @Getter private boolean modelLoaded; + @Getter private boolean modelVectorsLoaded; private VocabCache vocabCache; public FastText(File modelPath) { @@ -63,8 +103,97 @@ public class FastText implements WordVectors { fastTextImpl = new JFastText(); } - public void init() { - fastTextImpl = new JFastText(); + private static class ArgsFactory { + + private List args = new ArrayList<>(); + + private void add(String label, String value) { + args.add(label); + args.add(value); + } + + private void addOptional(String label, int value) { + if (value >= 0) { + args.add(label); + args.add(Integer.toString(value)); + } + } + + private void addOptional(String label, double value) { + if (value >= 0.0) { + args.add(label); + args.add(Double.toString(value)); + } + } + + private void addOptional(String label, String value) { + if (StringUtils.isNotEmpty(value)) { + args.add(label); + args.add(value); + } + } + + private void addOptional(String label, boolean value) { + if (value) { + args.add(label); + } + } + + + public String[] args() { + String[] asArray = new String[args.size()]; + return args.toArray(asArray); + } + } + + private String[] makeArgs() { + ArgsFactory argsFactory = new ArgsFactory(); + + argsFactory.addOptional("cbow", cbow); + argsFactory.addOptional("skipgram", skipgram); + argsFactory.addOptional("supervised", supervised); + argsFactory.addOptional("quantize", quantize); + argsFactory.addOptional("predict", predict); + argsFactory.addOptional("predict_prob", predict_prob); + + argsFactory.add("-input", inputFile); + argsFactory.add("-output", outputFile ); + + argsFactory.addOptional("-pretrainedVectors", pretrainedVectorsFile); + + argsFactory.addOptional("-bucket", bucket); + argsFactory.addOptional("-minCount", minCount); + argsFactory.addOptional("-minCountLabel", minCountLabel); + argsFactory.addOptional("-wordNgrams", wordNgrams); + argsFactory.addOptional("-minn", minNgramLength); + argsFactory.addOptional("-maxn", maxNgramLength); + argsFactory.addOptional("-t", samplingThreshold); + argsFactory.addOptional("-label", labelPrefix); + argsFactory.addOptional("analogies",analogies); + argsFactory.addOptional("-lr", learningRate); + argsFactory.addOptional("-lrUpdateRate", learningRateUpdate); + argsFactory.addOptional("-dim", dim); + argsFactory.addOptional("-ws", contextWindowSize); + argsFactory.addOptional("-epoch", epochs); + argsFactory.addOptional("-loss", lossName); + argsFactory.addOptional("-neg", negativeSamples); + argsFactory.addOptional("-thread", numThreads); + argsFactory.addOptional("-saveOutput", saveOutput); + argsFactory.addOptional("-cutoff", cutOff); + argsFactory.addOptional("-retrain", retrain); + argsFactory.addOptional("-qnorm", qnorm); + argsFactory.addOptional("-qout", qout); + argsFactory.addOptional("-dsub", dsub); + + return argsFactory.args(); + } + + public void fit() { + String[] cmd = makeArgs(); + fastTextImpl.runCmd(cmd); + } + + public void loadIterator() { if (iterator != null) { try { File tempFile = File.createTempFile("FTX", ".txt"); @@ -81,24 +210,11 @@ public class FastText implements WordVectors { } } - public void fit() { - - String[] cmd; - if (skipgram) { - cmd = new String[]{"skipgram", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount), - "-input", inputFile, "-output", outputFile}; - } - else if (cbow) { - cmd = new String[]{"cbow", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount), - "-input", inputFile, "-output", outputFile}; - } - else if (supervised) - cmd = new String[]{"supervised", "-input", inputFile, - "-output", outputFile}; - else - cmd = new String[]{"-input", inputFile, - "-output", outputFile}; - fastTextImpl.runCmd(cmd); + public void loadPretrainedVectors(File vectorsFile) { + word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile); + modelVectorsLoaded = true; + log.info("Loaded vectorized representation from file %s. Functionality will be restricted.", + vectorsFile.getAbsolutePath()); } public void loadBinaryModel(String modelPath) { @@ -111,10 +227,18 @@ public class FastText implements WordVectors { modelLoaded = false; } + public void test(File testFile) { + fastTextImpl.test(testFile.getAbsolutePath()); + } + + private void assertModelLoaded() { + if (!modelLoaded && !modelVectorsLoaded) + throw new IllegalStateException("Model must be loaded before predict!"); + } + public String predict(String text) { - if (!modelLoaded) - throw new IllegalStateException("Model must be loaded before predict!"); + assertModelLoaded(); String label = fastTextImpl.predict(text); return label; @@ -122,8 +246,7 @@ public class FastText implements WordVectors { public Pair predictProbability(String text) { - if (!modelLoaded) - throw new IllegalStateException("Model must be loaded before predict!"); + assertModelLoaded(); JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text); @@ -135,27 +258,39 @@ public class FastText implements WordVectors { @Override public VocabCache vocab() { - if (!modelLoaded) - throw new IllegalStateException("Load model before calling vocab()"); - - if (vocabCache == null) { - vocabCache = new AbstractCache(); + if (modelVectorsLoaded) { + vocabCache = word2Vec.vocab(); } - List words = fastTextImpl.getWords(); - for (int i = 0; i < words.size(); ++i) { - vocabCache.addWordToIndex(i, words.get(i)); - VocabWord word = new VocabWord(); - word.setWord(words.get(i)); - vocabCache.addToken(word); + else { + if (!modelLoaded) + throw new IllegalStateException("Load model before calling vocab()"); + + if (vocabCache == null) { + vocabCache = new AbstractCache(); + } + List words = fastTextImpl.getWords(); + for (int i = 0; i < words.size(); ++i) { + vocabCache.addWordToIndex(i, words.get(i)); + VocabWord word = new VocabWord(); + word.setWord(words.get(i)); + vocabCache.addToken(word); + } } return vocabCache; } @Override public long vocabSize() { - if (!modelLoaded) - throw new IllegalStateException("Load model before calling vocab()"); - return fastTextImpl.getNWords(); + long result = 0; + if (modelVectorsLoaded) { + result = word2Vec.vocabSize(); + } + else { + if (!modelLoaded) + throw new IllegalStateException("Load model before calling vocab()"); + result = fastTextImpl.getNWords(); + } + return result; } @Override @@ -170,99 +305,160 @@ public class FastText implements WordVectors { @Override public double[] getWordVector(String word) { - List vectors = fastTextImpl.getVector(word); - double[] retVal = new double[vectors.size()]; - for (int i = 0; i < vectors.size(); ++i) { - retVal[i] = vectors.get(i); + if (modelVectorsLoaded) { + return word2Vec.getWordVector(word); + } + else { + List vectors = fastTextImpl.getVector(word); + double[] retVal = new double[vectors.size()]; + for (int i = 0; i < vectors.size(); ++i) { + retVal[i] = vectors.get(i); + } + return retVal; } - return retVal; } @Override public INDArray getWordVectorMatrixNormalized(String word) { - INDArray r = getWordVectorMatrix(word); - return r.divi(Nd4j.getBlasWrapper().nrm2(r)); + if (modelVectorsLoaded) { + return word2Vec.getWordVectorMatrixNormalized(word); + } + else { + INDArray r = getWordVectorMatrix(word); + return r.divi(Nd4j.getBlasWrapper().nrm2(r)); + } } @Override public INDArray getWordVectorMatrix(String word) { - double[] values = getWordVector(word); - return Nd4j.createFromArray(values); + if (modelVectorsLoaded) { + return word2Vec.getWordVectorMatrix(word); + } + else { + double[] values = getWordVector(word); + return Nd4j.createFromArray(values); + } } @Override public INDArray getWordVectors(Collection labels) { + if (modelVectorsLoaded) { + return word2Vec.getWordVectors(labels); + } return null; } @Override public INDArray getWordVectorsMean(Collection labels) { + if (modelVectorsLoaded) { + return word2Vec.getWordVectorsMean(labels); + } return null; } + private List words = new ArrayList<>(); + @Override public boolean hasWord(String word) { - return fastTextImpl.getWords().contains(word); + if (modelVectorsLoaded) { + return word2Vec.outOfVocabularySupported(); + } + if (words.isEmpty()) + words = fastTextImpl.getWords(); + return words.contains(word); } - protected transient ModelUtils modelUtils = new BasicModelUtils<>(); + protected transient ModelUtils modelUtils; @Override public Collection wordsNearest(INDArray words, int top) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearest(words, top); + } return modelUtils.wordsNearest(words, top); } @Override public Collection wordsNearestSum(INDArray words, int top) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearestSum(words, top); + } return modelUtils.wordsNearestSum(words, top); } @Override public Collection wordsNearestSum(String word, int n) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearestSum(word, n); + } return modelUtils.wordsNearestSum(word, n); } @Override public Collection wordsNearestSum(Collection positive, Collection negative, int top) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearestSum(positive, negative, top); + } return modelUtils.wordsNearestSum(positive, negative, top); } @Override public Map accuracy(List questions) { + if (modelVectorsLoaded) { + return word2Vec.accuracy(questions); + } return modelUtils.accuracy(questions); } @Override public int indexOf(String word) { + if (modelVectorsLoaded) { + return word2Vec.indexOf(word); + } return vocab().indexOf(word); } @Override public List similarWordsInVocabTo(String word, double accuracy) { + if (modelVectorsLoaded) { + return word2Vec.similarWordsInVocabTo(word, accuracy); + } return modelUtils.similarWordsInVocabTo(word, accuracy); } @Override public Collection wordsNearest(Collection positive, Collection negative, int top) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearest(positive, negative, top); + } return modelUtils.wordsNearest(positive, negative, top); } @Override public Collection wordsNearest(String word, int n) { + if (modelVectorsLoaded) { + return word2Vec.wordsNearest(word,n); + } return modelUtils.wordsNearestSum(word, n); } @Override public double similarity(String word, String word2) { + if (modelVectorsLoaded) { + return word2Vec.similarity(word, word2); + } return modelUtils.similarity(word, word2); } @Override public WeightLookupTable lookupTable() { + if (modelVectorsLoaded) { + return word2Vec.lookupTable(); + } return null; } @@ -320,4 +516,9 @@ public class FastText implements WordVectors { return fastTextImpl.getLabelPrefix(); } + @Override + public boolean outOfVocabularySupported() { + return true; + } + } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java index 61bb87ff8..e5821e515 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java @@ -376,6 +376,11 @@ public class StaticWord2Vec implements WordVectors { return false; } + @Override + public boolean outOfVocabularySupported() { + return false; + } + public static class Builder { private AbstractStorage storage; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index f13d5dea0..b60af71d5 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -1,6 +1,10 @@ package org.deeplearning4j.models.fasttext; +import com.github.jfasttext.JFastText; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; +import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; +import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; @@ -13,6 +17,7 @@ import org.nd4j.resources.Resources; import java.io.File; import java.io.IOException; +import java.util.Arrays; import static org.junit.Assert.assertArrayEquals; @@ -23,7 +28,9 @@ import static org.junit.Assert.assertEquals; public class FastTextTest extends BaseDL4JTest { private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); - private File modelFile = Resources.asFile("models/fasttext/supervised.model.bin"); + private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); + private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); + private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec"); @Rule @@ -39,7 +46,6 @@ public class FastTextTest extends BaseDL4JTest { inputFile(inputFile.getAbsolutePath()). outputFile(output.getAbsolutePath()).build(); log.info("\nTraining supervised model ...\n"); - fastText.init(); fastText.fit(); } @@ -53,7 +59,6 @@ public class FastTextTest extends BaseDL4JTest { inputFile(inputFile.getAbsolutePath()). outputFile(output.getAbsolutePath()).build(); log.info("\nTraining supervised model ...\n"); - fastText.init(); fastText.fit(); } @@ -68,7 +73,6 @@ public class FastTextTest extends BaseDL4JTest { inputFile(inputFile.getAbsolutePath()). outputFile(output.getAbsolutePath()).build(); log.info("\nTraining supervised model ...\n"); - fastText.init(); fastText.fit(); } @@ -82,34 +86,42 @@ public class FastTextTest extends BaseDL4JTest { inputFile(inputFile.getAbsolutePath()). outputFile(output.getAbsolutePath()).build(); log.info("\nTraining supervised model ...\n"); - fastText.init(); fastText.fit(); } - @Ignore + @Test + public void tesLoadCBOWModel() throws IOException { + + FastText fastText = new FastText(cbowModelFile); + fastText.test(cbowModelFile); + + assertEquals(19, fastText.vocab().numWords()); + assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1)); + + double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4}; + assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4); + } + @Test public void testPredict() throws IOException { - for (int i = 0; i < 100; ++i) { String text = "I like soccer"; - FastText fastText = new FastText(modelFile); + FastText fastText = new FastText(supModelFile); assertEquals(48, fastText.vocab().numWords()); assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1)); double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582}; - assertArrayEquals(expected, fastText.getWordVector("association"), 1e-5); + assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4); String label = fastText.predict(text); assertEquals("__label__soccer", label); - } } - @Ignore @Test public void testPredictProbability() throws IOException { String text = "I like soccer"; - FastText fastText = new FastText(modelFile); + FastText fastText = new FastText(supModelFile); Pair result = fastText.predictProbability(text); assertEquals("__label__soccer", result.getFirst()); @@ -129,7 +141,7 @@ public class FastTextTest extends BaseDL4JTest { @Test public void testVocabulary() throws IOException { - FastText fastText = new FastText(modelFile); + FastText fastText = new FastText(supModelFile); assertEquals(48, fastText.vocab().numWords()); assertEquals(48, fastText.vocabSize()); @@ -149,7 +161,7 @@ public class FastTextTest extends BaseDL4JTest { SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); FastText fastText = FastText.builder().supervised(true).iterator(iter).build(); - fastText.init(); + fastText.loadIterator(); } catch (IOException e) { log.error(e.toString()); @@ -162,4 +174,60 @@ public class FastTextTest extends BaseDL4JTest { String label = fastText.predict("something"); } + @Test + public void testPretrainedVectors() throws IOException { + File output = testDir.newFile(); + + FastText fastText = + FastText.builder().supervised(true). + inputFile(inputFile.getAbsolutePath()). + pretrainedVectorsFile(supervisedVectors.getAbsolutePath()). + outputFile(output.getAbsolutePath()).build(); + log.info("\nTraining supervised model ...\n"); + fastText.fit(); + } + + @Test + public void testWordsStatistics() throws IOException { + + File output = testDir.newFile(); + + FastText fastText = + FastText.builder().supervised(true). + inputFile(inputFile.getAbsolutePath()). + outputFile(output.getAbsolutePath()).build(); + + log.info("\nTraining supervised model ...\n"); + fastText.fit(); + + Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec")); + + assertEquals(48, word2Vec.getVocab().numWords()); + + System.out.println(word2Vec.wordsNearest("association", 3)); + System.out.println(word2Vec.similarity("Football", "teams")); + System.out.println(word2Vec.similarity("professional", "minutes")); + System.out.println(word2Vec.similarity("java","cpp")); + } + + + @Test + public void testWordsNativeStatistics() throws IOException { + + File output = testDir.newFile(); + + FastText fastText = new FastText(); + fastText.loadPretrainedVectors(supervisedVectors); + + log.info("\nTraining supervised model ...\n"); + + assertEquals(48, fastText.vocab().numWords()); + + String[] result = new String[3]; + fastText.wordsNearest("association", 3).toArray(result); + assertArrayEquals(new String[]{"most","eleven","hours"}, result); + assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4); + assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4); + assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4); + } } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java index 8a39a7fe3..6643b23bf 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/serialization/WordVectorSerializerTest.java @@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils; +import org.deeplearning4j.models.fasttext.FastText; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; import org.deeplearning4j.models.sequencevectors.SequenceVectors; import org.deeplearning4j.models.word2vec.VocabWord; @@ -42,6 +43,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.IOException; import java.util.Collections; import static org.junit.Assert.*; @@ -289,4 +291,33 @@ public class WordVectorSerializerTest extends BaseDL4JTest { } } } + + @Test + public void FastText_Correct_WhenDeserialized() throws IOException { + + FastText fastText = + FastText.builder().cbow(true).build(); + + WordVectorSerializer.writeWordVectors(fastText, new File("some.data")); + + FastText deser = null; + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + deser = WordVectorSerializer.readWordVectors(new File("some.data")); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + + assertNotNull(deser); + assertEquals(fastText.isCbow(), deser.isCbow()); + assertEquals(fastText.isModelLoaded(), deser.isModelLoaded()); + assertEquals(fastText.isAnalogies(), deser.isAnalogies()); + assertEquals(fastText.isNn(), deser.isNn()); + assertEquals(fastText.isPredict(), deser.isPredict()); + assertEquals(fastText.isPredict_prob(), deser.isPredict_prob()); + assertEquals(fastText.isQuantize(), deser.isQuantize()); + assertEquals(fastText.getInputFile(), deser.getInputFile()); + assertEquals(fastText.getOutputFile(), deser.getOutputFile()); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 12c855434..d08fa852a 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -453,6 +453,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { DataType netDtype = getConfiguration().getDataType(); if(parameters != null && parameters.dataType() != netDtype){ + Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters); if(cloneParametersArray){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { parameters = parameters.castTo(netDtype); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java index 5da9bdece..2987d4be2 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java @@ -178,8 +178,7 @@ public class EmbeddingSequenceLayer extends BaseLayer(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java index bc573ca43..b5af4c88e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestSameDiffUI.java @@ -32,7 +32,7 @@ import java.util.Arrays; @Ignore public class TestSameDiffUI { -// @Ignore + @Ignore @Test public void testSameDiff() throws Exception { diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 361f8054c..f60e760d6 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1598,9 +1598,6 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// int NDArray::rankOf() const { - if (isEmpty()) - return 0; - return shape::rank(_shapeInfo); } diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index fb1503452..fc23b581c 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -36,7 +36,7 @@ std::string NDArray::e(const Nd4jLong i) const; template NDArray* NDArray::asT() const{ - auto result = new NDArray(ordering(), isScalar() ? std::vector({0}) : getShapeAsVector(), DataTypeUtils::fromT()); + auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); auto l = this->lengthOf(); prepareSpecialUse({result}, {this}); @@ -67,17 +67,18 @@ NDArray::NDArray(const NDArray& other) { //////////////////////////////////////////////////////////////////////// NDArray::NDArray(const char order, const std::vector &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) { - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); - if ((int) shape.size() > MAX_RANK) throw std::invalid_argument("Rank of NDArray can't exceed 32"); _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; + _isAttached = _context->getWorkspace() != nullptr; _offset = 0; - setShapeInfo(ShapeDescriptor(dtype, order, shape)); + if (shape.empty()) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); _buffer->setToZeroBuffers(); } @@ -85,16 +86,20 @@ NDArray::NDArray(const char order, const std::vector &shape, nd4j::Dat //////////////////////////////////////////////////////////////////////// NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, nd4j::DataType dtype, nd4j::LaunchContext * context) { - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); - if ((int) shape.size() > MAX_RANK) throw std::invalid_argument("Rank of NDArray can't exceed 32"); _context = context; _offset = 0; - setShapeInfo(ShapeDescriptor(dtype, order, shape)); + if (shape.size() == 0) { + if (data.size() == 0) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + } else { + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + } if (lengthOf() != data.size()) { nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); @@ -2441,6 +2446,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB())) throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); + if (isEmpty() || other->isEmpty()) + return; + NDArray::prepareSpecialUse({target}, {this, other}); if (isScalar()) { @@ -2513,6 +2521,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* if(target == nullptr || other == nullptr) throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !"); + if (isEmpty() || other->isEmpty()) + return; + NDArray::prepareSpecialUse({target}, {this, other}); if (isScalar()) { @@ -2583,6 +2594,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* ////////////////////////////////////////////////////////////////////////// NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + Nd4jLong* newShapeInfo = nullptr; if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); @@ -2812,6 +2830,19 @@ bool NDArray::reshapei(const char order, const std::vector& cshape) { if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) return true; + const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); + + if(isEmpty() && !isOutShapeEmpty) + throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !"); + if(!isEmpty() && isOutShapeEmpty) + throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !"); + if(isEmpty() && isOutShapeEmpty) { + Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace()); + setShapeInfo(shapeInfoNew); + RELEASE(shapeInfoNew, getContext()->getWorkspace()); + return true; + } + std::vector shape(cshape); int rank = shape.size(); @@ -2823,7 +2854,7 @@ bool NDArray::reshapei(const char order, const std::vector& cshape) { for (int i = 0; i < (int) shape.size(); i++) { if (shape[i] < 0) { if (numberNegativesOnes >= 1) - throw std::runtime_error("Only one dimension can be negative at once"); + throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once"); numberNegativesOnes++; @@ -3664,7 +3695,7 @@ void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, co if(rankOf() == copy.size() || copy.empty()) { NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo()); } - else { + else { //if (!isEmpty()) { auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr; auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy); NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets()); @@ -4198,6 +4229,9 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector& d // operator returns sub-array with buffer pointing at this->_buffer + certain offset NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { + if(isEmpty()) + throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !"); + const int rank = rankOf(); Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace()); @@ -4260,6 +4294,9 @@ NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector& di //////////////////////////////////////////////////////////////////////// void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { + if(isEmpty()) + throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !"); + const int rank = rankOf(); const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size(); const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 021ca29a3..4d451ed4b 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1334,18 +1334,7 @@ public: * @param npyArray * @return */ - Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { - cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); - unsigned int shapeSize = arr.shape.size(); - auto shape = new unsigned int[shapeSize]; - for(unsigned int i = 0; i < shapeSize; i++) { - shape[i] = arr.shape[i]; - } - - auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder); - delete[] shape; - return reinterpret_cast(shapeBuffer); - } + Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray); /** diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index 74e186301..55db6ddf6 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -64,8 +64,8 @@ void NDArray::tickWriteDevice() const { } void NDArray::tickReadHost() const { } void NDArray::tickReadDevice() const { } void NDArray::tickBothActual() const { } -bool NDArray::isActualOnHostSide() const { } -bool NDArray::isActualOnDeviceSide() const { } +bool NDArray::isActualOnHostSide() const { return true; } +bool NDArray::isActualOnDeviceSide() const { return true; } void NDArray::makeBothBuffersActual() const { } @@ -419,328 +419,8 @@ void NDArray::repeat(int dimension, NDArray& target) const { ////////////////////////////////////////////////////////////////////////// #ifndef __JAVACPP_HACK__ -template -void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; +#include "NDArrayLambda.hpp" - if (second == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n",""); - throw std::runtime_error("second is null"); - } - - if (third == nullptr) { - nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n",""); - throw std::runtime_error("third is null"); - } - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType()) - throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); - - if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = second->bufferAsT(); - auto t = third->bufferAsT(); - auto z = target->bufferAsT(); - - if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong e = 0; e < _length; e++) - z[e] = func(f[e], s[e], t[e]); - } else { - if (f == z) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); - - f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - } else { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto tOffset = this->getOffset(e); - auto uOffset = second->getOffset(e); - auto vOffset = third->getOffset(e); - auto zOffset = target->getOffset(e); - - z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); - } - } - } -} -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); - template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); - template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); - template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); -template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; - - if (other == nullptr) { - nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != other->dataType() || dataType() != target->dataType()) - throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); - - if (this->lengthOf() != other->lengthOf()) { - nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); - - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) - z[e] = func(f[e], s[e]); - } else { - if (f == z) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto yOffset = other->getOffset(e); - - f[xOffset] = func(f[xOffset], s[yOffset]); - } - } else { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); - - z[zOffset] = func(f[xOffset], s[yOffset]); - } - } - } -} -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) - throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target->bufferAsT(); - - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) - z[e] = func(f[e]); - } else { - if (f == z) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - - f[xOffset] = func(f[xOffset]); - } - } else { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); - - z[zOffset] = func(f[xOffset]); - } - } - } -} -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); - template void NDArray::applyLambda(const std::function& func, NDArray* target); - template void NDArray::applyLambda(const std::function& func, NDArray* target); - template void NDArray::applyLambda(const std::function& func, NDArray* target); -template void NDArray::applyLambda(const std::function& func, NDArray* target); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedLambda(const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; - - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) - throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); - - auto f = this->bufferAsT(); - auto z = target->bufferAsT(); - - if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong e = 0; e < _length; e++) - z[e] = func(e, f[e]); - } else { - if (f == z) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - - f[xOffset] = func(e, f[xOffset]); - } - } else { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto zOffset = target->getOffset(e); - - z[zOffset] = func(e, f[xOffset]); - } - } - } -} -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); - template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); - template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); - template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); -template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); - -////////////////////////////////////////////////////////////////////////// -template -void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target) { - if (target == nullptr) - target = this; - - if (other == nullptr) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); - throw std::runtime_error("Other is null"); - } - if(dataType() != DataTypeUtils::fromT()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); - if(dataType() != target->dataType()) - throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); - if (this->lengthOf() != other->lengthOf()) { - nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); - throw std::runtime_error("Shapes mismach"); - } - - auto f = this->bufferAsT(); - auto s = other->bufferAsT(); - auto z = target->bufferAsT(); - - if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong e = 0; e < _length; e++) - z[e] = func((Nd4jLong) e, f[e], s[e]); - } else { - if (f == z) { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto yOffset = other->getOffset(e); - - f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - } else { - - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (int e = 0; e < _length; e++) { - - auto xOffset = this->getOffset(e); - auto yOffset = other->getOffset(e); - auto zOffset = target->getOffset(e); - - z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); - } - } - } -} -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); - template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); -template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); #endif /* diff --git a/libnd4j/blas/cpu/NDArrayLambda.hpp b/libnd4j/blas/cpu/NDArrayLambda.hpp new file mode 100644 index 000000000..5006b18ae --- /dev/null +++ b/libnd4j/blas/cpu/NDArrayLambda.hpp @@ -0,0 +1,325 @@ + + + +template +void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target) { + if (target == nullptr) + target = this; + + if (second == nullptr) { + nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n",""); + throw std::runtime_error("second is null"); + } + + if (third == nullptr) { + nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n",""); + throw std::runtime_error("third is null"); + } + if(dataType() != DataTypeUtils::fromT()) + throw std::runtime_error("NDArray::applyTriplewiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); + if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType()) + throw std::runtime_error("NDArray::applyTriplewiseLambda method: bother four arrays (this, second, third, target) should have the same type !"); + + if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) { + nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = second->bufferAsT(); + auto t = third->bufferAsT(); + auto z = target->bufferAsT(); + + if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong e = 0; e < _length; e++) + z[e] = func(f[e], s[e], t[e]); + } else { + if (f == z) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto tOffset = this->getOffset(e); + auto uOffset = second->getOffset(e); + auto vOffset = third->getOffset(e); + + f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]); + } + } else { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto tOffset = this->getOffset(e); + auto uOffset = second->getOffset(e); + auto vOffset = third->getOffset(e); + auto zOffset = target->getOffset(e); + + z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]); + } + } + } +} +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); +template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function& func, NDArray* target); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target) { + if (target == nullptr) + target = this; + + if (other == nullptr) { + nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); + throw std::runtime_error("Other is null"); + } + + if(dataType() != DataTypeUtils::fromT()) + throw std::runtime_error("NDArray::applyPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); + if(dataType() != other->dataType() || dataType() != target->dataType()) + throw std::runtime_error("NDArray::applyPairwiseLambda method: all three arrays (this, other, target) must have the same type !"); + + if (this->lengthOf() != other->lengthOf()) { + nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n",""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other->bufferAsT(); + auto z = target->bufferAsT(); + + if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) + z[e] = func(f[e], s[e]); + } else { + if (f == z) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto yOffset = other->getOffset(e); + + f[xOffset] = func(f[xOffset], s[yOffset]); + } + } else { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto yOffset = other->getOffset(e); + auto zOffset = target->getOffset(e); + + z[zOffset] = func(f[xOffset], s[yOffset]); + } + } + } +} +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function& func, NDArray* target); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::applyLambda(const std::function& func, NDArray* target) { + if (target == nullptr) + target = this; + + if(dataType() != DataTypeUtils::fromT()) + throw std::runtime_error("NDArray::applyLambda method: wrong template parameter T, its type should be the same as type of this array!"); + if(dataType() != target->dataType()) + throw std::runtime_error("NDArray::applyLambda method: types of this and target array should match !"); + + auto f = this->bufferAsT(); + auto z = target->bufferAsT(); + + if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) + z[e] = func(f[e]); + } else { + if (f == z) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + + f[xOffset] = func(f[xOffset]); + } + } else { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto zOffset = target->getOffset(e); + + z[zOffset] = func(f[xOffset]); + } + } + } +} +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); +template void NDArray::applyLambda(const std::function& func, NDArray* target); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::applyIndexedLambda(const std::function& func, NDArray* target) { + if (target == nullptr) + target = this; + + if(dataType() != DataTypeUtils::fromT()) + throw std::runtime_error("NDArray::applyIndexedLambda method: wrong template parameter T, its type should be the same as type of this array!"); + if(dataType() != target->dataType()) + throw std::runtime_error("NDArray::applyIndexedLambda method: types of this and target array should match !"); + + auto f = this->bufferAsT(); + auto z = target->bufferAsT(); + + if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong e = 0; e < _length; e++) + z[e] = func(e, f[e]); + } else { + if (f == z) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + + f[xOffset] = func(e, f[xOffset]); + } + } else { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto zOffset = target->getOffset(e); + + z[zOffset] = func(e, f[xOffset]); + } + } + } +} +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); +template void NDArray::applyIndexedLambda(const std::function& func, NDArray* target); + +////////////////////////////////////////////////////////////////////////// +template +void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target) { + if (target == nullptr) + target = this; + + if (other == nullptr) { + nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n",""); + throw std::runtime_error("Other is null"); + } + if(dataType() != DataTypeUtils::fromT()) + throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: wrong template parameter T, its type should be the same as type of this array!"); + if(dataType() != target->dataType()) + throw std::runtime_error("NDArray::applyIndexedPairwiseLambda method: types of this and target array should match !"); + if (this->lengthOf() != other->lengthOf()) { + nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n",""); + throw std::runtime_error("Shapes mismach"); + } + + auto f = this->bufferAsT(); + auto s = other->bufferAsT(); + auto z = target->bufferAsT(); + + if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong e = 0; e < _length; e++) + z[e] = func((Nd4jLong) e, f[e], s[e]); + } else { + if (f == z) { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto yOffset = other->getOffset(e); + + f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); + } + } else { + + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (int e = 0; e < _length; e++) { + + auto xOffset = this->getOffset(e); + auto yOffset = other->getOffset(e); + auto zOffset = target->getOffset(e); + + z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]); + } + } + } +} +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); +template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function& func, NDArray* target); \ No newline at end of file diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 2c369471c..bdebeacb8 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2710,6 +2710,32 @@ int NativeOps::dataTypeFromNpyHeader(void *header) { return (int) cnpy::dataTypeFromHeader(reinterpret_cast(header)); } +Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for(unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) + _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (_empty) { + if (shapeSize > 0) + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + } + return reinterpret_cast(shapeBuffer); +} + BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES); diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index ad205588e..ea85e864e 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -454,7 +454,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre if (ews() != 1) { for (uint i = 0; i < _length; i++) - cudaMemcpyAsync(pHost + i * sizeof(T), getSpecialBuffer() + getOffset(i) * sizeof(T), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream())); + cudaMemcpyAsync(reinterpret_cast(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream())); } else cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream()); @@ -475,6 +475,12 @@ template void NDArray::printCurrentBuffer(const bool host, const char* ms template void NDArray::printCurrentBuffer(const bool host, const char* msg, const int precision) const; +#if defined(__CUDACC__) && !defined(BUILD_TESTS) + +#include + +#endif + } // end namespace nd4j #endif diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 9c51afd58..d857faf35 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3105,3 +3105,29 @@ nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, double nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) { return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); } + +Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) { + cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); + unsigned int shapeSize = arr.shape.size(); + std::vector shape(shapeSize); + bool _empty = false; + for(unsigned int i = 0; i < shapeSize; i++) { + shape[i] = arr.shape[i]; + + if (arr.shape[i] == 0) + _empty = true; + } + + auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast(npyArray)); + + Nd4jLong *shapeBuffer; + if (_empty) { + if (shapeSize > 0) + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + else + shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype); + } else { + shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape); + } + return reinterpret_cast(shapeBuffer); +} diff --git a/libnd4j/include/array/DataTypeUtils.h b/libnd4j/include/array/DataTypeUtils.h index a95cc6857..e58166b15 100644 --- a/libnd4j/include/array/DataTypeUtils.h +++ b/libnd4j/include/array/DataTypeUtils.h @@ -53,6 +53,15 @@ namespace nd4j { template FORCEINLINE static _CUDA_HD T max(); + /** + * returns inf for float/double and max for everything else + */ + template + FORCEINLINE static _CUDA_HD T infOrMax(); + + template + FORCEINLINE static _CUDA_HD T nanOrZero(); + // returns the difference between 1.0 and the next representable value of the given floating-point type template FORCEINLINE static T eps(); @@ -290,6 +299,36 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max() { return bfloat16::max(); } +template <> +FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax() { + return std::numeric_limits::infinity(); +} + +template <> +FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax() { + return std::numeric_limits::infinity(); +} + +template +FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() { + return DataTypeUtils::max(); +} + +template <> +FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero() { + return std::numeric_limits::quiet_NaN(); +} + +template <> +FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero() { + return std::numeric_limits::quiet_NaN(); +} + +template +FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() { + return static_cast(0); +} + FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) { switch(dataType) { case INT8: diff --git a/libnd4j/include/array/impl/ShapeDescriptor.cpp b/libnd4j/include/array/impl/ShapeDescriptor.cpp index b01356c56..356177163 100644 --- a/libnd4j/include/array/impl/ShapeDescriptor.cpp +++ b/libnd4j/include/array/impl/ShapeDescriptor.cpp @@ -55,8 +55,13 @@ bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const { } Nd4jLong* ShapeDescriptor::toShapeInfo() const { - if (_empty) - return ShapeBuilders::emptyShapeInfo(_dataType); + if (_empty) { + if (_rank == 0) + return ShapeBuilders::emptyShapeInfo(_dataType); + else { + return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); + } + } switch (_rank) { @@ -133,15 +138,11 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd ////////////////////////////////////////////////////////////////////////// ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape): _dataType(type), _order(order), _shape(shape) { - _rank = ((shape.size() == 1 && shape[0] == 0)? 0: shape.size()); + _rank = shape.size(); _ews = 1; if (_rank > 0) { _strides.resize(_rank); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); for (auto v:_shape) { if (v == 0) { @@ -149,6 +150,17 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st break; } } + + // no point calculating strides for empty arrays + if (!_empty) { + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + // all strides set to 0 + memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); + } } } @@ -191,8 +203,11 @@ ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { _empty = shape::isEmpty(shapeInfo); - for (int e = 0; e < _rank; e++) + for (int e = 0; e < _rank; e++) { _shape.emplace_back(shapeInfo[e + 1]); + if (shapeInfo[e + 1] == 0) + _empty = true; + } for (int e = 0; e < _rank; e++) _strides.emplace_back(shapeInfo[e + 1 + _rank]); @@ -304,7 +319,14 @@ ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const D ShapeDescriptor descriptor; descriptor._dataType = type; descriptor._shape.emplace_back(length); - descriptor._strides.emplace_back(1); + + if (length > 0) + descriptor._strides.emplace_back(1); + else { + descriptor._strides.emplace_back(0); + descriptor._empty = true; + } + descriptor._order = 'c'; descriptor._ews = 1; descriptor._rank = 1; diff --git a/libnd4j/include/helpers/ShapeBuilders.h b/libnd4j/include/helpers/ShapeBuilders.h index 8af7068af..49ef20e9f 100644 --- a/libnd4j/include/helpers/ShapeBuilders.h +++ b/libnd4j/include/helpers/ShapeBuilders.h @@ -29,7 +29,7 @@ #include namespace nd4j { - class ShapeBuilders { + class ND4J_EXPORT ShapeBuilders { public: static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr); @@ -53,6 +53,8 @@ namespace nd4j { static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr); + static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace = nullptr); + }; } diff --git a/libnd4j/include/helpers/ShapeUtils.h b/libnd4j/include/helpers/ShapeUtils.h index 422d6ecc5..3f30585b0 100644 --- a/libnd4j/include/helpers/ShapeUtils.h +++ b/libnd4j/include/helpers/ShapeUtils.h @@ -40,6 +40,12 @@ namespace nd4j { static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr); static Nd4jLong* evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr); + /** + * evaluate output shape for reduce operation when input shape is empty + * behavior is analogous to tf + */ + static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace); + // evaluate shape for array which is result of repeat operation applied to arr static std::vector evalRepeatShape(int dimension, const std::vector& repeats, const NDArray& arr); diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index ea29bf5f0..fa3d78684 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -71,7 +71,9 @@ namespace nd4j { auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; auto oPtr = new Nd4jLong[numOfSubArrs]; - shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); + if (numOfSubArrs > 0) + shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); + ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64); ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64); diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index bc67f251d..0a4bc6115 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -75,7 +75,8 @@ namespace nd4j { auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; auto oPtr = new Nd4jLong[numOfSubArrs]; - shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); + if (numOfSubArrs > 0) + shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape()); Nd4jPointer soPtr; auto res = cudaMalloc(reinterpret_cast(&soPtr), numOfSubArrs * sizeof(Nd4jLong)); diff --git a/libnd4j/include/helpers/impl/ShapeBuilders.cpp b/libnd4j/include/helpers/impl/ShapeBuilders.cpp index 2a9e30f99..70aa934ca 100644 --- a/libnd4j/include/helpers/impl/ShapeBuilders.cpp +++ b/libnd4j/include/helpers/impl/ShapeBuilders.cpp @@ -54,11 +54,6 @@ namespace nd4j { //////////////////////////////////////////////////////////////////////////////// Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) { - - if (rank) - if(shapeOnly[0] == 0) // scalar case - rank = 0; - Nd4jLong* shapeInfo = nullptr; if(rank == 0) { // scalar case @@ -67,10 +62,23 @@ namespace nd4j { else { ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); shapeInfo[0] = rank; - for(int i = 0; i < rank; ++i) + bool isEmpty = false; + for(int i = 0; i < rank; ++i) { shapeInfo[i + 1] = shapeOnly[i]; - shape::updateStrides(shapeInfo, order); + if (shapeOnly[i] == 0) + isEmpty = true; + } + + if (!isEmpty) { + shape::updateStrides(shapeInfo, order); + } + else { + shapeInfo[shape::shapeInfoLength(rank) - 1] = order; + memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + } + nd4j::ArrayOptions::setDataType(shapeInfo, dataType); } @@ -78,9 +86,16 @@ namespace nd4j { } Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) { - auto shape = createScalarShapeInfo(dataType, workspace); - ArrayOptions::setPropertyBit(shape, ARRAY_EMPTY); - return shape; + auto shapeInfo = createScalarShapeInfo(dataType, workspace); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; + } + + Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector &shape, memory::Workspace* workspace) { + auto shapeInfo = createShapeInfo(dataType, order, shape, workspace); + memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong)); + ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY); + return shapeInfo; } //////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/helpers/impl/ShapeUtils.cpp b/libnd4j/include/helpers/impl/ShapeUtils.cpp index 4494f886d..252bc7c52 100644 --- a/libnd4j/include/helpers/impl/ShapeUtils.cpp +++ b/libnd4j/include/helpers/impl/ShapeUtils.cpp @@ -108,27 +108,81 @@ std::vector ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt); } -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimensions, arr, arr.dataType(), keepDims, supportOldShapes, workspace); + +////////////////////////////////////////////////////////////////////////// +// evaluate output shape for reduce operation when input shape is empty +Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) { + + if (dimsToExclude.size() == 0) { // return copy of input shape + Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace); + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); + } + + const int rank = shape::rank(shapeInfo); + Nd4jLong* outShapeInfo = nullptr; + + if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities + + if(!keepDims) + outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace); + else + outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector(rank, 1), workspace); + } + else { + + shape::checkDimensions(rank, dimsToExclude); + + std::vector outShape; + + if(keepDims) { + outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank); + for(const auto& dim : dimsToExclude) + outShape[dim] = 1; + } + else { + for (uint i = 0, j = 0; i < rank; ++i) { + if(j < dimsToExclude.size() && i == dimsToExclude[j]) + ++j; + else + outShape.emplace_back(shapeInfo[i + 1]); + } + } + + outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace); + } + + ShapeDescriptor descriptor(outShapeInfo, dataType); + RELEASE(outShapeInfo, workspace); + return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); } -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimensions, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace); +} + +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace); } ////////////////////////////////////////////////////////////////////////// -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { - return evalReduceShapeInfo(order, dimensions, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace); +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { + return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace); } ////////////////////////////////////////////////////////////////////////// // evaluate shape resulting from reduce operation -Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { +Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) { + + if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY) + return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace); + Nd4jLong* newShapeInfo = nullptr; int rank = shape::rank(const_cast(shapeInfo)); - if (dimensions.size() == 0) { // return scalar or array with len=1 in this case + if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case if(keepDims && rank > 1) { ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); @@ -157,16 +211,16 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& di } } - shape::checkDimensions(rank, dimensions); + shape::checkDimensions(rank, dimsToExclude); - int dimSize = dimensions.size(); + int dimSize = dimsToExclude.size(); if(keepDims) { ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong); newShapeInfo[0] = rank; for(int i = 0; i < rank; ++i) - if (std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied + if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied newShapeInfo[i+1] = 1; else newShapeInfo[i+1] = shapeInfo[i+1]; @@ -178,7 +232,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& di } int newRank = rank - dimSize; - if (newRank==0 || (dimSize==1 && dimensions[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension + if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension if(supportOldShapes) { ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); @@ -199,7 +253,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& di newShapeInfo[0] = newRank; // set rank int j=1; for(int i = 0; i < rank; ++i) - if (!std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied + if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied newShapeInfo[j++] = shapeInfo[i+1]; //ensure whether vector has proper shape for old shape type @@ -208,7 +262,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector& di RELEASE(newShapeInfo, workspace); ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2 newShapeInfo[0] = 2; - if (dimensions[0] == 0) { + if (dimsToExclude[0] == 0) { newShapeInfo[1] = 1; newShapeInfo[2] = oldValue; } @@ -422,8 +476,23 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i]; + // nullify zero axis + for (int e = 0; e < maxRank; e++) + if (maxShapeInfo[e+1] == 0) + tmpShapeInfo[e+1] = 0; + + int delta = maxRank - minRank; + for (int e = minRank - 1; e >= 0; e--) + if (minShapeInfo[e + 1] == 0) + tmpShapeInfo[e + 1 + delta] = 0; + ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo)); + if (shape::isEmpty(max) || shape::isEmpty(min)) { + ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY); + memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong)); + } + ShapeDescriptor descriptor(tmpShapeInfo); RELEASE(tmpShapeInfo, workspace); resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT(); @@ -805,7 +874,7 @@ std::vector ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]); throw std::invalid_argument(""); } - return std::vector({0}); + return std::vector({}); } diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index d9a540838..c40ddc1e6 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -1992,7 +1992,7 @@ template len = shape::length(shapeInfo); //check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute - if(len < 2) + if(len == 1) return; const int rank = shape::rank(shapeInfo); @@ -3961,7 +3961,7 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, newDim = newShape[newStart]; oldDim = oldShape[oldStart]; - while (newDim != oldDim) + while (newDim != oldDim && newDim > 0 && oldDim > 0) if (newDim < oldDim) newDim *= newShape[newStop++]; else oldDim *= oldShape[oldStop++]; diff --git a/libnd4j/include/loops/cpu/indexreduce.cpp b/libnd4j/include/loops/cpu/indexreduce.cpp index d87c56677..951ac287b 100644 --- a/libnd4j/include/loops/cpu/indexreduce.cpp +++ b/libnd4j/include/loops/cpu/indexreduce.cpp @@ -41,7 +41,7 @@ void IndexReduce::exec(const int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, Nd4jLong *z, Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, + int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffset) { DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, tadShapeInfo, tadOffset), INDEX_REDUCE_OPS); @@ -51,7 +51,7 @@ DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, dime template template Nd4jLong IndexReduce::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams) { - + auto x = reinterpret_cast(vx); auto extraParams = reinterpret_cast(vextraParams); @@ -116,13 +116,23 @@ void IndexReduce::exec(void *vx, Nd4jLong *xShapeInfo, auto x = reinterpret_cast(vx); auto extraParams = reinterpret_cast(vextraParams); + const Nd4jLong zLen = shape::length(zShapeInfo); + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto indexValue = OpType::startingIndexValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < zLen; i++) + z[i] = indexValue.index;; + return; + } + if(shape::isScalar(zShapeInfo)) { z[0] = execScalar(x,xShapeInfo,extraParams); return; } - const Nd4jLong zLen = shape::length(zShapeInfo); - auto tadOnlyShapeInfo = tadShapeInfo; Nd4jLong *tadOffsets = tadOffset; diff --git a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp index dbe787aae..91f4144e4 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_bool.cpp @@ -45,7 +45,22 @@ namespace functions { const Nd4jLong length = shape::length(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo); - + + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < length; i++) + z[i] = startingVal; + return; + } + if (xEws >= 1) { z[0] = execScalar(x, xEws, length, extraParams); } @@ -82,7 +97,7 @@ namespace functions { const Nd4jLong length = shape::length(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo); - + if (xEws >= 1) { return execScalar(x, xEws, length, extraParams); } @@ -157,6 +172,16 @@ namespace functions { auto resultLength = shape::length(zShapeInfo); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < resultLength; i++) + z[i] = startingVal; + return; + } + //pre squeezed: this is for keeping the pointer to the original //shape information for tad offset //the squeezed information doesn't render the right strides for @@ -212,9 +237,9 @@ namespace functions { if (xEws == 1) { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -223,15 +248,15 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); + startingVal = OpType::update(startingVal, local, extraParams); } } else { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + xEws*threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -240,8 +265,8 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); - } + startingVal = OpType::update(startingVal, local, extraParams); + } } return OpType::postProcess(startingVal, length, extraParams); } diff --git a/libnd4j/include/loops/cpu/reduce/reduce_float.cpp b/libnd4j/include/loops/cpu/reduce/reduce_float.cpp index 5de435a3e..8d04b7cdb 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_float.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_float.cpp @@ -45,7 +45,26 @@ namespace functions { const Nd4jLong length = shape::length(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo); - + + if (shape::isEmpty(xShapeInfo)) { + if (std::is_same>::value) { + z[0] = nd4j::DataTypeUtils::nanOrZero(); + } else { + z[0] = OpType::startingValue(x); + } + return; + } + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < length; i++) + z[i] = startingVal; + return; + } + if (xEws > 0) { z[0] = execScalar(x, xEws, length, extraParams); } @@ -69,7 +88,7 @@ namespace functions { start = OpType::update(start, intermediate[e], extraParams); z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams); - } + } } @@ -165,6 +184,16 @@ namespace functions { auto resultLength = shape::length(zShapeInfo); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = std::is_same>::value ? nd4j::DataTypeUtils::nanOrZero() : static_cast(OpType::startingValue(x)); + PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < resultLength; i++) + z[i] = startingVal; + return; + } + //pre squeezed: this is for keeping the pointer to the original //shape information for tad offset //the squeezed information doesn't render the right strides for diff --git a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp index d6ae96d72..9069f4198 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_long.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_long.cpp @@ -46,6 +46,21 @@ namespace functions { const Nd4jLong length = shape::length(xShapeInfo); auto xEws = shape::elementWiseStride(xShapeInfo); + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < length; i++) + z[i] = startingVal; + return; + } + if (xEws >= 1) { z[0] = execScalar(x, xEws, length, extraParams); } @@ -105,7 +120,7 @@ namespace functions { delete[] intermediate; return OpType::postProcess(start, shape::length(xShapeInfo), extraParams); - } + } } @@ -159,6 +174,16 @@ namespace functions { auto resultLength = shape::length(zShapeInfo); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < resultLength; i++) + z[i] = startingVal; + return; + } + //pre squeezed: this is for keeping the pointer to the original //shape information for tad offset //the squeezed information doesn't render the right strides for @@ -209,7 +234,7 @@ namespace functions { template template Z _CUDA_H ReduceLongFunction::execScalar(void *vx, Nd4jLong xEws, Nd4jLong length, void *vextraParams) { - + auto x = reinterpret_cast(vx); auto extraParams = reinterpret_cast(vextraParams); @@ -219,9 +244,9 @@ namespace functions { if (xEws == 1) { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -230,15 +255,15 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); + startingVal = OpType::update(startingVal, local, extraParams); } } else { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + xEws*threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -247,8 +272,8 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); - } + startingVal = OpType::update(startingVal, local, extraParams); + } } return OpType::postProcess(startingVal, length, extraParams); } diff --git a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp index 571ddda9e..676348017 100644 --- a/libnd4j/include/loops/cpu/reduce/reduce_same.cpp +++ b/libnd4j/include/loops/cpu/reduce/reduce_same.cpp @@ -48,6 +48,20 @@ namespace functions { const auto xEws = shape::elementWiseStride(xShapeInfo); const int rank = shape::rank(xShapeInfo); + if (shape::isEmpty(xShapeInfo)) { + z[0] = OpType::startingValue(x); + return; + } + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < length; i++) + z[i] = startingVal; + return; + } if (xEws >= 1) { z[0] = execScalar(x, xEws, length, extraParams); @@ -71,7 +85,7 @@ namespace functions { for (int e = 0; e < maxThreads; e++) start = OpType::update(start, intermediate[e], extraParams); - z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams); + z[0] = OpType::postProcess(start, length, extraParams); } } @@ -171,6 +185,16 @@ namespace functions { auto zLength = shape::length(zShapeInfo); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < zLength; i++) + z[i] = startingVal; + return; + } + //pre squeezed: this is for keeping the pointer to the original //shape information for tad offset //the squeezed information doesn't render the right strides for @@ -231,9 +255,9 @@ namespace functions { if (xEws == 1) { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -242,15 +266,15 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); + startingVal = OpType::update(startingVal, local, extraParams); } } else { PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) - { + { auto local = OpType::startingValue(x); - auto threadNum = omp_get_thread_num(); + auto threadNum = omp_get_thread_num(); auto threadOffset = info.getThreadOffset(threadNum); auto xi = x + xEws*threadOffset; auto ulen = static_cast(info.getItersPerThread(threadNum)); @@ -259,8 +283,8 @@ namespace functions { local = OpType::update(local, OpType::op(xi[i*xEws], extraParams), extraParams); PRAGMA_OMP_CRITICAL - startingVal = OpType::update(startingVal, local, extraParams); - } + startingVal = OpType::update(startingVal, local, extraParams); + } } return OpType::postProcess(startingVal, length, extraParams); } diff --git a/libnd4j/include/loops/cpu/reduce3.cpp b/libnd4j/include/loops/cpu/reduce3.cpp index 5da61f26b..eeea227c8 100644 --- a/libnd4j/include/loops/cpu/reduce3.cpp +++ b/libnd4j/include/loops/cpu/reduce3.cpp @@ -37,7 +37,7 @@ void Reduce3::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) { - + auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); @@ -47,11 +47,21 @@ void Reduce3::execScalar(void *vx, Nd4jLong *xShapeInfo, auto xEws = shape::elementWiseStride(xShapeInfo); auto yEws = shape::elementWiseStride(yShapeInfo); + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY || nd4j::ArrayOptions::arrayType(yShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + const auto startingVal = OpType::startingValue(x); + PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < length; i++) + z[i] = startingVal; + return; + } + Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f}; // it's possible case for EqualsWithEps op - if (extraParams != nullptr) + if (extraParams != nullptr) extraParamsVals[2] = extraParams[0]; - + uint xShapeInfoCast[MAX_RANK]; const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); @@ -117,7 +127,7 @@ void Reduce3::execScalar(const int opNum, void *extraParamsVals, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo) { - + DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(vx, xShapeInfo, extraParamsVals, vy, yShapeInfo, vz, zShapeInfo), REDUCE3_OPS); } @@ -176,8 +186,8 @@ void Reduce3:: execAll(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, - int *dimension, int dimensionLength, - Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, + int *dimension, int dimensionLength, + Nd4jLong *xTadShapeInfo, Nd4jLong *xOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yOffsets) { auto x = reinterpret_cast(vx); diff --git a/libnd4j/include/loops/cpu/summarystatsreduce.cpp b/libnd4j/include/loops/cpu/summarystatsreduce.cpp index 5d29532d1..ed398db28 100644 --- a/libnd4j/include/loops/cpu/summarystatsreduce.cpp +++ b/libnd4j/include/loops/cpu/summarystatsreduce.cpp @@ -47,8 +47,8 @@ namespace functions { Nd4jLong *xShapeInfo, void *extraParams, void *z, - Nd4jLong *resultShapeInfoBuffer) { - DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer), SUMMARY_STATS_OPS); + Nd4jLong *zShapeInfo) { + DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS); } template @@ -58,10 +58,10 @@ namespace functions { Nd4jLong *xShapeInfo, void *extraParams, void *z, - Nd4jLong *resultShapeInfoBuffer, + Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer, dimension, dimensionLength), SUMMARY_STATS_OPS); + DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS); } template @@ -71,7 +71,7 @@ namespace functions { Nd4jLong *xShapeInfo, void *vextraParams, void *vz, - Nd4jLong *resultShapeInfoBuffer) { + Nd4jLong *zShapeInfo) { auto z = reinterpret_cast(vz); z[0] = execScalar(biasCorrected, vx, xShapeInfo, vextraParams); } @@ -86,12 +86,12 @@ namespace functions { SummaryStatsData startingIndex; startingIndex.initialize(); auto length = shape::length(xShapeInfo); - + uint xShapeInfoCast[MAX_RANK]; const bool canCast = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - + for (Nd4jLong i = 0; i < length; i++) { - + auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCast); SummaryStatsData curr; @@ -99,7 +99,7 @@ namespace functions { startingIndex = update(startingIndex, curr, extraParams); } - return OpType::getValue(biasCorrected, startingIndex); + return OpType::getValue(biasCorrected, startingIndex); } template @@ -108,20 +108,31 @@ namespace functions { void *vx, Nd4jLong *xShapeInfo, void *vextraParams, - void *vresult, - Nd4jLong *resultShapeInfoBuffer, + void *vz, + Nd4jLong *zShapeInfo, int *dimension, int dimensionLength) { - auto x = reinterpret_cast(vx); - auto z = reinterpret_cast(vresult); - auto extraParams = reinterpret_cast(vextraParams); - if (shape::isScalar(resultShapeInfoBuffer)) { - z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + int resultLength = shape::length(zShapeInfo); + + if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) { + if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY) + return; + SummaryStatsData comp; + comp.initWithValue(x[0]); + PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold()) + for (uint i = 0; i < resultLength; i++) + z[i] = OpType::getValue(biasCorrected, comp); return; } - + if (shape::isScalar(zShapeInfo)) { + z[0] = execScalar(biasCorrected, x, xShapeInfo, extraParams); + return; + } //no-op if (dimensionLength < 1) @@ -129,7 +140,6 @@ namespace functions { auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); - int resultLength = shape::length(resultShapeInfoBuffer); //pre squeezed: this is for keeping the pointer to the original //shape information for tad offset //the squeezed information doesn't render the right strides for @@ -149,7 +159,7 @@ namespace functions { PRAGMA_OMP_PARALLEL_FOR for (int r = 0; r < resultLength; r++) { - + auto tadOffsetForBlock = tadPack.primaryOffsets()[r]; auto tx = x + tadOffsetForBlock; SummaryStatsData comp; diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index b81a184a0..4260f6ffa 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -131,7 +131,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/convo/conv1d.cpp index 2f3739e11..a9898875d 100644 --- a/libnd4j/include/ops/declarable/generic/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/conv1d.cpp @@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) { auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput); auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - ConvolutionUtils::conv2d(*block.launchContext(), inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); + ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); delete inputReshaped; delete outputReshaped; @@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) { auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] - ConvolutionUtils::conv2dBP(*block.launchContext(), inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); + ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW); delete inputReshaped; delete gradIReshaped; diff --git a/libnd4j/include/ops/declarable/generic/convo/conv2d.cpp b/libnd4j/include/ops/declarable/generic/convo/conv2d.cpp index 22cc47b36..be9edd538 100644 --- a/libnd4j/include/ops/declarable/generic/convo/conv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/conv2d.cpp @@ -63,7 +63,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) { if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); return Status::OK(); } @@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) { if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); return Status::OK(); } @@ -305,7 +305,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) { REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/conv3d.cpp b/libnd4j/include/ops/declarable/generic/convo/conv3d.cpp index 5999545e2..a3c578cfc 100644 --- a/libnd4j/include/ops/declarable/generic/convo/conv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/conv3d.cpp @@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) { permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC] NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] // [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC] MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput); @@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { // ----- calculation of gradW and gradB ----- // NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] + ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC] if(gradB) { @@ -469,7 +469,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) { //----- calculation of gradI -----// MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW] - ConvolutionUtils::col2vol(*block.launchContext(), columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] + ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] if(!isNDHWC) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/generic/convo/deconv2d_tf.cpp index 814193984..801337dc8 100644 --- a/libnd4j/include/ops/declarable/generic/convo/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/deconv2d_tf.cpp @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) { REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); - ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/deconv3d.cpp b/libnd4j/include/ops/declarable/generic/convo/deconv3d.cpp index aafa39917..3cdd19e46 100644 --- a/libnd4j/include/ops/declarable/generic/convo/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/deconv3d.cpp @@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) { // NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW] // NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW] nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW] - ConvolutionUtils::col2vol(*block.launchContext(), columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] + ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW] //----- add biases if required -----// if(bias) @@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) { // ----- calculation of gradW ----- // auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext()); - ConvolutionUtils::vol2col(*block.launchContext(), *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] + ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW] MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW] // ----- calculation of gradB ----- // diff --git a/libnd4j/include/ops/declarable/generic/convo/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/convo/depthwiseConv2d.cpp index cc714ed2b..1a0652462 100644 --- a/libnd4j/include/ops/declarable/generic/convo/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/depthwiseConv2d.cpp @@ -62,7 +62,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) { if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::depthwiseConv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); + ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); return Status::OK(); } @@ -185,7 +185,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) { if(bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/pointwiseConv2d.cpp b/libnd4j/include/ops/declarable/generic/convo/pointwiseConv2d.cpp index 0886c4480..69435ecb2 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pointwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pointwiseConv2d.cpp @@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) { if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); + ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool2d.cpp b/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool2d.cpp index d7d515c4f..33b152ea6 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool2d.cpp @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0); //output->printBuffer("output op"); if (!isNCHW) { @@ -198,7 +198,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) { // *gradI /= kH*kW; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0); + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0); if(!isNCHW) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool3d.cpp b/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool3d.cpp index 36e95809a..4712edbe5 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pooling/avgpool3d.cpp @@ -69,7 +69,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) { ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); //T extraParams[] = {}; - ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); if(!isNCDHW) { delete input; @@ -189,7 +189,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) { ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0); if(!isNCDHW) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool2d.cpp b/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool2d.cpp index 29ee8d63c..a8ef611c8 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool2d.cpp @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) { ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1); + ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1); if (!isNCHW) { delete input; @@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) { // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.); + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.); if(!isNCHW) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool3d.cpp index acfb65a44..b5edf2f34 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pooling/maxpool3d.cpp @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) { if(isSameMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); + ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); if(!isNCDHW) { delete input; @@ -204,7 +204,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { // ConvolutionUtils::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary; - ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); + ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1); if(!isNCDHW) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/pooling/pnormpool2d.cpp b/libnd4j/include/ops/declarable/generic/convo/pooling/pnormpool2d.cpp index 0aa41bf64..6ed620c65 100644 --- a/libnd4j/include/ops/declarable/generic/convo/pooling/pnormpool2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/pooling/pnormpool2d.cpp @@ -68,7 +68,7 @@ namespace nd4j { ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0); + ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0); if (!isNCHW) { delete input; @@ -209,7 +209,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) { // columns->template applyTransform>(gradI, std::vector({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data()); - ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm); + ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm); if(!isNCHW) { delete input; diff --git a/libnd4j/include/ops/declarable/generic/convo/sconv2d.cpp b/libnd4j/include/ops/declarable/generic/convo/sconv2d.cpp index 2c18ea890..3e81be61d 100644 --- a/libnd4j/include/ops/declarable/generic/convo/sconv2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/sconv2d.cpp @@ -84,11 +84,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) { if (iC == 1) { nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n",""); - ConvolutionUtils::conv2d(*block.launchContext(), input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); return Status::OK(); } - ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); return Status::OK(); } @@ -274,12 +274,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { auto resultFFShape = isNCHW ? std::vector({bS, mC*iC, oH, oW}) : std::vector({bS, oH, oW, mC*iC}); auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext()); - ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - ConvolutionUtils::conv2dBP(*block.launchContext(), resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW + ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW gradO = gradIDepth; bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step @@ -288,7 +288,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) { } // ----- apply depthwise_conv2d_bp ----- // - ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); if(weightsPoint) delete gradO; diff --git a/libnd4j/include/ops/declarable/generic/convo/upsampling2d.cpp b/libnd4j/include/ops/declarable/generic/convo/upsampling2d.cpp index 5d3091eaf..2978feff1 100644 --- a/libnd4j/include/ops/declarable/generic/convo/upsampling2d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/upsampling2d.cpp @@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) { REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf()); - ConvolutionUtils::upsampling2d(*block.launchContext(), *input, *output, factorH, factorW, (bool)isNCHW); + ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW); return Status::OK(); } @@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) { REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - ConvolutionUtils::upsampling2dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCHW); + ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/convo/upsampling3d.cpp b/libnd4j/include/ops/declarable/generic/convo/upsampling3d.cpp index 7b4dc022c..9cdbbffc2 100644 --- a/libnd4j/include/ops/declarable/generic/convo/upsampling3d.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/upsampling3d.cpp @@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) { REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf()); REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf()); - ConvolutionUtils::upsampling3d(*block.launchContext(), *input, *output, factorD, factorH, factorW, (bool)isNCDHW); + ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW); return Status::OK(); } @@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) { REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf()); REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf()); - ConvolutionUtils::upsampling3dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCDHW); + ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp index 4bd9ba23a..eafa266dd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/batchnorm.cpp @@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) { } std::vector shape({2, mean->lengthOf()}); - NDArray weights = NDArrayFactory::create('c', shape, block.getWorkspace()); + NDArray weights = NDArrayFactory::create('c', shape, block.launchContext()); weights({0, 1, 0, 0}).assign(1.0f); weights({1, 2, 0, 0}).assign(0.0f); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp index cbe97ebe4..bdfdfb6c6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmax.cpp @@ -72,6 +72,11 @@ namespace nd4j { if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + + for (auto d:dims) { + REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape"); + } + // special case - output is scalar if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp index 6e8b99985..a80194eb2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/argmin.cpp @@ -72,6 +72,10 @@ namespace nd4j { if (dims.size() > 1) std::sort(dims.begin(), dims.end()); + for (auto d:dims) { + REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape"); + } + // special case - output is scalar if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) { return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp index c049e1116..9fc508860 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fill.cpp @@ -71,10 +71,8 @@ namespace nd4j { ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong); newShape[0] = len; - auto empty = false; for (int e = 0; e < shapeArray->lengthOf(); e++){ newShape[e+1] = shapeArray->e(e); - empty |= (newShape[e+1] == 0); //Support "zeros in shape as empty" for TF import } nd4j::DataType dataType; @@ -90,10 +88,6 @@ namespace nd4j { } else throw std::runtime_error("Fill: missing value to fill output array with"); - if(empty){ - return SHAPELIST(ShapeBuilders::emptyShapeInfo(dataType, block.getWorkspace())); - } - ShapeUtils::updateStridesAndType(newShape, dataType, 'c'); return SHAPELIST(CONSTANT(newShape)); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp index 007b42daf..04b5b48d6 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/range.cpp @@ -151,8 +151,10 @@ DECLARE_SHAPE_FN(range) { delta = INPUT_VARIABLE(2)->e(0); } - if (limit == start) - return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); + if (limit == start){ + //Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); + } REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); @@ -177,8 +179,10 @@ DECLARE_SHAPE_FN(range) { //nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta) - if (limit == start) - return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype)); + if (limit == start){ + //Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype)); + } REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); @@ -203,8 +207,10 @@ DECLARE_SHAPE_FN(range) { delta = INT_ARG(2); } - if (limit == start) - return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(nd4j::DataType::INT32)); + if (limit == start){ + //Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::INT32)); + } REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); @@ -233,9 +239,10 @@ DECLARE_SHAPE_FN(range) { delta = T_ARG(2); } - //REQUIRE_TRUE(limit != start, 0, "CUSTOM RANGE OP: limit and start values should be different, but got both equal to %f !", limit); - if (limit == start) - return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(Environment::getInstance()->defaultFloatDataType())); + if (limit == start){ + //Return [0] to match TF + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType())); + } REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !"); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp index 47fe792de..7a15967d5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp @@ -31,7 +31,8 @@ namespace nd4j { REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - output->assign(static_cast(input->rankOf())); +// output->assign(static_cast(input->rankOf())); + output->assign(input->rankOf()); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp index 35cf4dc48..326b49c65 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/reduce_max.cpp @@ -88,7 +88,7 @@ DECLARE_TYPES(reduce_max) { ->setSameMode(true); } -#endif +#endif #if NOT_EXCLUDED(OP_reduce_max_bp) ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp index f6090875e..d2a390eb9 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/slice.cpp @@ -132,17 +132,13 @@ namespace nd4j { REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]); REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]); REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]); - if(start == inShape[e+1] || size == 0 ){ - empty = true; + if(start == inShape[e+1] ){ + size = 0; } shape.emplace_back(size); } - if(empty){ - return SHAPELIST(ShapeBuilders::emptyShapeInfo(nd4j::DataType::INT32, block.getWorkspace())); - } - auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp index 72013dcc2..f251551d0 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stack.cpp @@ -33,6 +33,10 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) { int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; if(dim < 0) dim += input->rankOf() + 1; + + // no-op in case of empty output array + if (output->isEmpty()) + return Status::OK(); // input validation // check whether shapes of all input array are the same @@ -47,16 +51,6 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) { inArrs[i] = INPUT_VARIABLE(i); helpers::stack(block.launchContext(), inArrs, output, dim); - - // remove unity from output shape if input arrays are vectors - // if(input->isVector()) { - // std::vector outShape(output->shapeOf(), output->shapeOf() + output->rankOf()); - // outShape.erase(find(outShape.begin(), outShape.end(), 1)); - // output->reshapei(output->ordering(), outShape); - // if(dim != 0 && (int)block.width() == 1) // such is implemented by tensorFlow - // output->permutei({1, 0}); - // output->getShapeInfo()[output->rankOf()*2 + 2] = 1; - // } return Status::OK(); } @@ -81,9 +75,23 @@ DECLARE_SHAPE_FN(stack) { dim += rank + 1; REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim); - + + // empty input arrays require some special handling + if (shape::isEmpty(inShapeInfo)) { + switch (rank) { + case 0: { + // we're going to return rank 1 here + if (block.width() == 1) { + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo))); + } else { + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0})); + } + } + } + } + if(rank == 0) { - return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo))); + return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo))); } //the rank of output ShapeInfo is larger by one compared to input ShapeInfo @@ -91,13 +99,9 @@ DECLARE_SHAPE_FN(stack) { // insert (int) block.width() at dim position of input shape to get output shape outShape.insert(outShape.begin() + Nd4jLong(dim), (Nd4jLong) block.width()); - return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape))); + return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape))); } -// 1) 1Ñ…4 + 1Ñ…4 = 2Ñ…1Ñ…4 (along dim=0) = 2x4 -// 2) 1Ñ…4 + 1Ñ…4 = 1Ñ…2Ñ…4 (along dim=1) = 2x4 -// 3) 4Ñ…1 + 4Ñ…1 = 2Ñ…4x1 (along dim=0) = 2x4 -// 4) 4Ñ…1 + 4Ñ…1 = 4Ñ…2x1 (along dim=1) = 4x2 } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp index 582cce9df..d768516fe 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/strided_slice.cpp @@ -410,10 +410,13 @@ namespace nd4j { // z->assign(x->e(indices[0])); // } // else { - auto sub = (*x)(indices, true, true); - z->assign(sub); -// } - + if (indices.size()) { + auto sub = (*x)(indices, true, true); + z->assign(sub); + } + else if (!z->isEmpty()){ + z->assign(x->e(0)); + } return Status::OK(); } DECLARE_SYN(stridedslice, strided_slice); @@ -496,28 +499,19 @@ namespace nd4j { bool is_simple_slice; bool is_dim0; - // FIXME: remove this, once we bring in 1D NDArrays - //vectorize(input_shape); - bool result = _preprocess_strided_slice(nullptr, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); - bool nonEmpty = shape.size() > 0; - if (nonEmpty) - for (auto x: shape) { - if (x == 0) { - nonEmpty = false; - break; - } - } - if (nonEmpty && inputLen > 1) { - newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape); - } - else { - if (shape::rank(inShape) == 0 || begin >= end) { - newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape)); + std::vector indices; + bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0); + if (indices.size()) { + newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', + shape); + if (inputLen > 1) { + newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', + shape); } else { newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)); - } - } + } else + newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape)); return SHAPELIST(newShape); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp index 5a6f6b12e..f6ac319ab 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unstack.cpp @@ -37,6 +37,9 @@ namespace nd4j { REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim); REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim); + if(input->isEmpty()) + return Status::OK(); + std::vector dims; for (int e = 0; e < input->rankOf(); e++) if (e != dim) @@ -65,7 +68,7 @@ namespace nd4j { return Status::OK(); } DECLARE_SYN(unpack, unstack); - + DECLARE_SHAPE_FN(unstack) { auto inShape = inputShape->at(0); @@ -76,6 +79,21 @@ namespace nd4j { REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim); REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim); + if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) { + if(shape::shapeOf(inShape)[dim] == 0) + return SHAPELIST(); + const Nd4jLong numTads = shape::shapeOf(inShape)[dim]; + std::vector outShape; + for(uint i = 0; i < shape::rank(inShape); ++i) + if(i != dim) + outShape.push_back(shape::shapeOf(inShape)[i]); + + auto result = SHAPELIST(); + for(uint i = 0; i < numTads; ++i) + result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape)); + return result; + } + std::vector dims; for (int e = 0; e < shape::rank(inShape); e++) if (e != dim) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp index 137853400..2b623d23e 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/zero_fraction.cpp @@ -30,6 +30,12 @@ namespace nd4j { auto output = OUTPUT_VARIABLE(0); REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); + + if(input->isEmpty()){ + output->p(0, std::numeric_limits::quiet_NaN()); + return Status::OK(); + } + int numZeros = 0; // for (int e = 0; e < input->lengthOf(); e++) // if ((*input)(e) == T(0)) diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp b/libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp index 5d96478e7..debca0053 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/lstmBlock.cpp @@ -113,16 +113,10 @@ DECLARE_SHAPE_FN(lstmBlock) { } ShapeUtils::updateStridesAndType(s, x, 'c'); - Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6; - COPY_SHAPE(s, s1); - COPY_SHAPE(s, s2); - COPY_SHAPE(s, s3); - COPY_SHAPE(s, s4); - COPY_SHAPE(s, s5); - COPY_SHAPE(s, s6); + Nd4jLong *s1 = CONSTANT(s); //7 outputs, all same shape/type - return SHAPELIST(s, s1, s2, s3, s4, s5, s6); + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } } diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp index 3c8a62d12..446b523c1 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/lstmBlockCell.cpp @@ -115,16 +115,10 @@ DECLARE_SHAPE_FN(lstmBlockCell) { ShapeUtils::updateStridesAndType(s, xt, 'c'); - Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6; - COPY_SHAPE(s, s1); - COPY_SHAPE(s, s2); - COPY_SHAPE(s, s3); - COPY_SHAPE(s, s4); - COPY_SHAPE(s, s5); - COPY_SHAPE(s, s6); + Nd4jLong *s1 = CONSTANT(s); //7 outputs, all same shape: z, i, f, o, h, c, y - return SHAPELIST(s, s1, s2, s3, s4, s5, s6); + return SHAPELIST(s1, s1, s1, s1, s1, s1, s1); } } diff --git a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp index 706fe50aa..3c9030058 100644 --- a/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/broadcast_to.cpp @@ -78,7 +78,6 @@ DECLARE_SHAPE_FN(broadcast_to) { REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str()); auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape); - return SHAPELIST(outShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/shape/permute.cpp b/libnd4j/include/ops/declarable/generic/shape/permute.cpp index 76f84df66..7c21e73e4 100644 --- a/libnd4j/include/ops/declarable/generic/shape/permute.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/permute.cpp @@ -34,26 +34,26 @@ namespace nd4j { bool replace = false; - auto arguments = block.getIArguments(); - if (block.width() == 2 && arguments->size() == 0) { - auto axis = INPUT_VARIABLE(1); - for (int e = 0; e < axis->lengthOf(); e++) { - int ax = axis->e(e); + auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + std::vector arguments({}); + if(origArgs.size() > 0){ + for (int e = 0; e < origArgs.size(); e++) { + int ax = origArgs[e]; if (ax < 0) ax += x->rankOf(); - arguments->emplace_back(ax); + arguments.emplace_back(ax); } replace = true; - } else if (arguments->size() == 0) { + } else { for (int e = x->rankOf() - 1; e >= 0; e--) - arguments->emplace_back(e); + arguments.emplace_back(e); } // 0D edge case if (x->rankOf() == 0) { - REQUIRE_TRUE(arguments->size() == 1, 0, "Permute: only one axis is allowed for scalar"); + REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar"); auto output = OUTPUT_VARIABLE(0); if (!block.isInplace()) output->assign(x); @@ -62,25 +62,17 @@ namespace nd4j { } if(block.isInplace()) { // in-place - x->permutei(*arguments); + x->permutei(arguments); STORE_RESULT(x); - } else { - if (!replace) { // not-in-place - auto output = OUTPUT_VARIABLE(0); - // nd4j_printv("permute shape", *arguments); - auto result = x->permute(*arguments); - output->assign(result); - STORE_RESULT(output); - delete result; - } else { - auto output = OUTPUT_VARIABLE(0); //->dup(); - output->assign(x); - output->permutei(*arguments); - - //OVERWRITE_RESULT(output); - } + } else { + auto output = OUTPUT_VARIABLE(0); + auto result = x->permute(arguments); + output->assign(result); + STORE_RESULT(output); + delete result; } - return Status::OK(); + + return Status::OK(); } DECLARE_TYPES(permute) { @@ -92,20 +84,21 @@ namespace nd4j { DECLARE_SHAPE_FN(permute) { auto shapeList = SHAPELIST(); - auto arguments = block.getIArguments(); + auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT() : *block.getIArguments(); + if (shape::rank(inputShape->at(0)) == 0) { shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0)))); - } else if (inputShape->size() == 1 && !arguments->empty()) { - shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace())); - } else if (inputShape->size() == 2) { - // dead end - shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0)))); + } else if (inputShape->size() == 1 && !arguments.empty()) { + shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace())); } else { - int rank = shape::rank(inputShape->at(0)); - for (int e = rank - 1; e >= 0; e--) - arguments->emplace_back(e); + if(arguments.size() == 0){ + //Reverse dimensions + int rank = shape::rank(inputShape->at(0)); + for (int e = rank - 1; e >= 0; e--) + arguments.emplace_back(e); + } - shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace())); + shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace())); } return shapeList; diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 53106e48e..63e35d90e 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -35,9 +35,8 @@ namespace nd4j { auto arguments = block.getIArguments(); int argsSize = arguments->size(); - //Special case: empty.reshape(-1) -> return empty + //Special case: empty.reshape() -> return empty if (x->isEmpty()) { - REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]"); REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); return ND4J_STATUS_OK; //No op } @@ -96,9 +95,9 @@ namespace nd4j { //Special case: empty.reshape(-1) -> return empty if (x->isEmpty()) { - REQUIRE_TRUE(s->lengthOf() == 1 && s->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); + //REQUIRE_TRUE(s->lengthOf() == 1 && s->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); - return ND4J_STATUS_OK; //No op + return Status::OK(); //No op } char order = 'c'; @@ -116,7 +115,8 @@ namespace nd4j { } for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){ REQUIRE_TRUE(s->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - shapeLength *= s->e(e2); + shapeLength *= + s->e(e2); } long realShape = x->lengthOf() / shapeLength; shapeNew[e] = realShape; @@ -175,12 +175,12 @@ namespace nd4j { e = 0; } - //Special case: empty.reshape(-1) -> return empty - if (INPUT_VARIABLE(0)->isEmpty()) { - REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]"); - auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); - return SHAPELIST(newShape); - } +// //Special case: empty.reshape(-1) -> return empty +// if (INPUT_VARIABLE(0)->isEmpty()) { +// // +// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); +// return SHAPELIST(newShape); +// } std::vector shapeNew; @@ -197,8 +197,14 @@ namespace nd4j { shapeLength *= arguments->at(e2); } - long realShape = shape::length(inp) / shapeLength; - shapeNew.push_back(realShape); + if(shapeLength == 0){ + //Edge case for empty: + shapeNew.push_back(0); + } else { + //Standard case + long realShape = shape::length(inp) / shapeLength; + shapeNew.push_back(realShape); + } } else{ shapeNew.push_back(arguments->at(e)); @@ -218,9 +224,16 @@ namespace nd4j { } //Special case: empty.reshape(-1) -> return empty if (x->isEmpty()) { - REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); - auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); - return SHAPELIST(newShape); + //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); + auto shapeOf = y->getBufferAsVector(); + Nd4jLong prod = 1; + for (auto v:shapeOf) + prod *= v; + + REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); + + auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); + return SHAPELIST(CONSTANT(newShape)); } std::vector shapeNew(y->lengthOf()); @@ -236,8 +249,14 @@ namespace nd4j { REQUIRE_TRUE(y->e(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); shapeLength *= y->e(e2); } - long realShape = shape::length(inp) / shapeLength; - shapeNew[e] = realShape; + + if(shapeLength == 0){ + //Edge case for empty: + shapeNew[e] = 0; + } else { + long realShape = shape::length(inp) / shapeLength; + shapeNew[e] = realShape; + } }else { shapeNew[e] = dim; } diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 3de2d297d..ac211d17d 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -38,26 +38,31 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) { std::vector arrsToDelete; int index = 0; bool allOfSameType = true; - + auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0; + auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType(); for(int i = 0; i < block.width(); ++i) { - - if(!INPUT_VARIABLE(i)->isEmpty()) { - - allOfSameType &= (INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType()); - if(INPUT_VARIABLE(i)->rankOf() == 0) { - // FIXME, use this instead: block.dataType() - auto vec = new NDArray('c', {1}, INPUT_VARIABLE(0)->dataType(), block.launchContext()); - vec->assign(INPUT_VARIABLE(i)); + auto input = INPUT_VARIABLE(i); + auto currentRank = input->rankOf(); + +// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy +// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank); +// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank); + if(!input->isEmpty()) { + + allOfSameType &= (theFirstDatatype == input->dataType()); + if(input->rankOf() == 0) { + auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext()); + vec->assign(input); nonEmptyArrs.push_back(vec); arrsToDelete.push_back(index); } else{ - nonEmptyArrs.push_back(INPUT_VARIABLE(i)); + nonEmptyArrs.push_back(input); } ++index; } } - + const int numOfArrs = nonEmptyArrs.size(); if(numOfArrs == 0){ @@ -73,21 +78,21 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) { REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !"); REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - for(int i = 1; i < numOfArrs; ++i) + for(int i = 1; i < numOfArrs; ++i) REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - for(int i = 1; i < numOfArrs; ++i) { + for(int i = 1; i < numOfArrs; ++i) { for(int dim = 0; dim < rank; ++dim) - if(dim != axis) + if(dim != axis) REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); } // ******** end of input validation ******** // auto output = OUTPUT_VARIABLE(0); - if(numOfArrs == 1) + if(numOfArrs == 1) output->assign(nonEmptyArrs[0]); - else + else helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis); // delete dynamically allocated vectors with length=1 @@ -110,36 +115,27 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) { DECLARE_SHAPE_FN(concat) { REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided"); - + // first of all take into account possible presence of empty arrays - // also if scalar is present -> use the shape of vector with length=1 instead - std::vector nonEmptyArrShapes; + // also if scalar is present -> use the shape of vector with length=1 instead + std::vector arrShapes; std::vector shapesToDelete; int index = 0; for(int i = 0; i < block.width(); ++i) { - - if(!INPUT_VARIABLE(i)->isEmpty()) { - - if(inputShape->at(i)[0] == 0) { - // FIXME, use this instead: block.dataType() - nonEmptyArrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); - } - else{ - nonEmptyArrShapes.push_back(inputShape->at(i)); - } - ++index; + + if(inputShape->at(i)[0] == 0) { + // FIXME, use this instead: block.dataType() + arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); } + else{ + arrShapes.push_back(inputShape->at(i)); + } + ++index; } - const int numOfArrs = nonEmptyArrShapes.size(); + const int numOfArrs = arrShapes.size(); - if(numOfArrs == 0){ - //All inputs are empty arrays -> return empty, mainly for TF import compatibility - auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType()); - return SHAPELIST(empty); - } - - const int rank = nonEmptyArrShapes[0][0]; // look up to first non-empty array + const int rank = arrShapes[0][0]; int axis = INT_ARG(0); if(axis < 0) @@ -148,34 +144,34 @@ DECLARE_SHAPE_FN(concat) { // ******** input validation ******** // REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis); - for(int i = 1; i < numOfArrs; ++i) - REQUIRE_TRUE(nonEmptyArrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !"); + for(int i = 1; i < numOfArrs; ++i) + REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !"); - for(int i = 1; i < numOfArrs; ++i) { + for(int i = 1; i < numOfArrs; ++i) { for(int dim = 0; dim < rank; ++dim) - if(dim != axis) - REQUIRE_TRUE(nonEmptyArrShapes[i][dim+1] == nonEmptyArrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); + if(dim != axis) + REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !"); } // ******** end of input validation ******** // - + Nd4jLong* outShapeInfo(nullptr); - COPY_SHAPE(nonEmptyArrShapes[0], outShapeInfo); - + COPY_SHAPE(arrShapes[0], outShapeInfo); + // case when we have only one input array - if(numOfArrs == 1) { - ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0])); + if(numOfArrs == 1) { + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); return SHAPELIST(CONSTANT(outShapeInfo)); } for(int i = 1; i < numOfArrs; ++i) - outShapeInfo[axis + 1] += nonEmptyArrShapes[i][axis + 1]; + outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; - ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0])); + ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); // delete dynamically allocated vectors shapes with length=1 for(int index : shapesToDelete) - RELEASE(nonEmptyArrShapes[index], block.getWorkspace()); + RELEASE(arrShapes[index], block.getWorkspace()); auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo)); RELEASE(outShapeInfo, block.getWorkspace()); @@ -277,7 +273,7 @@ DECLARE_SHAPE_FN(concat) { // DECLARE_SYN(ParallelConcat, concat); // DECLARE_SYN(concat_v2, concat); // DECLARE_SYN(concatv2, concat); - + // DECLARE_SHAPE_FN(concat) { // auto inp = inputShape->at(0); // int _dimension = INT_ARG(0); @@ -338,7 +334,7 @@ DECLARE_SHAPE_FN(concat) { // } // } - + // ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong); // if (_dimension < 0) @@ -382,11 +378,11 @@ DECLARE_SHAPE_FN(concat) { auto epsilonChunk = OUTPUT_VARIABLE(e); std::vector indices(2 * epsilonNext->rankOf()); - int width = originalChunk->sizeAt(axis); + int width = originalChunk->sizeAt(axis); for (int e = 0; e < epsilonNext->rankOf(); e++) { if (e == axis) - indices[2*e + 1] = (indices[2*e] = startPos) + width; + indices[2*e + 1] = (indices[2*e] = startPos) + width; else indices[2*e + 1] = indices[2*e] = 0; } @@ -394,7 +390,7 @@ DECLARE_SHAPE_FN(concat) { auto subarray = (*epsilonNext)(indices, true); epsilonChunk->assign(subarray); - startPos += width; + startPos += width; } return ND4J_STATUS_OK; diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp index a3bf5fe95..d539211a9 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumprod.cpp @@ -32,6 +32,11 @@ namespace nd4j { REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); + if(input->isEmpty()){ + //No-op + return Status::OK(); + } + const bool exclusive = INT_ARG(0) == 1; const bool reverse = INT_ARG(1) == 1; diff --git a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp index 75c3455dc..260965fb6 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/cumsum.cpp @@ -36,6 +36,11 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) { REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal"); + if(input->isEmpty()){ + //No-op + return Status::OK(); + } + if (block.getIArguments()->size() == 2 && block.width() == 1) { // all at once case nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse); diff --git a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp index f8ddae840..2e38f9977 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/gather.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/gather.cpp @@ -102,12 +102,6 @@ DECLARE_SHAPE_FN(gather) { if(axis < 0) axis += inputRank; - //Edge case: empty indices, empty input -> empty output - if(block.width() > 1 && INPUT_VARIABLE(0)->isEmpty() && INPUT_VARIABLE(1)->isEmpty()){ - auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType()); - return SHAPELIST(empty); - } - REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank); bool isEmpty = false; @@ -118,11 +112,6 @@ DECLARE_SHAPE_FN(gather) { int indicesRank = shape::rank(indicesShapeInfo); int outputRank = inputRank + indicesRank - 1; - - if(INPUT_VARIABLE(1)->isEmpty()) { //Empty indices -> empty output - outputRank = 0; - isEmpty = true; - } ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong); diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 8267fef4c..1d2b25678 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -33,6 +33,11 @@ namespace ops { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + if(output->isEmpty()){ + //No-op + return Status::OK(); + } + std::vector axis; if (block.width() > 1) diff --git a/libnd4j/include/ops/declarable/helpers/convolutions.h b/libnd4j/include/ops/declarable/helpers/convolutions.h index fa7518f96..7e23945fa 100644 --- a/libnd4j/include/ops/declarable/helpers/convolutions.h +++ b/libnd4j/include/ops/declarable/helpers/convolutions.h @@ -204,6 +204,8 @@ namespace nd4j { const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, + mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r); static void getMKLDNNMemoryDescConv3d( @@ -212,56 +214,60 @@ namespace nd4j { const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md, mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md, + mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md, mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r); static void getMKLDNNMemoryDescPool2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, + mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); static void getMKLDNNMemoryDescPool3d( int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, + mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r); #endif - static void conv2d(nd4j::LaunchContext &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); - static void conv2d(nd4j::LaunchContext & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); + static void conv2d(nd4j::graph::Context & block, const std::vector& inArrs, NDArray* output, const std::vector& intArgs); - static void conv2dBP(nd4j::LaunchContext & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); + static void conv2dBP(nd4j::graph::Context & block, const std::vector& inArrs, const std::vector& outArrs, const std::vector& intArgs); - static void conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); - static void depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); - static void depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); - static void sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); + static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW); - static void vol2col(nd4j::LaunchContext & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); + static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - static void col2vol(nd4j::LaunchContext & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); + static void col2vol(nd4j::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW); - static void upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW); + static void upsampling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW); - static void upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW); + static void upsampling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW); - static void upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW); + static void upsampling2dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW); - static void upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW); + static void upsampling3dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW); - static void pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0); + static void pooling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0); - static void pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); + static void pooling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); - static void pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0); + static void pooling2dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0); - static void pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); + static void pooling3dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0); }; } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp index f8666b7ec..6c7fb7fea 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/activations.cpp @@ -153,7 +153,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr if (inEWS == 1) { PRAGMA_OMP_SIMD_MAX(max) for (int i = 0; i < length; i++) - max = nd4j::math::nd4j_max(max, outBuff[i]); + max = nd4j::math::nd4j_max(max, inBuff[i]); PRAGMA_OMP_SIMD_SUM(sum) for (int i = 0; i < length; i++) { @@ -171,7 +171,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr PRAGMA_OMP_SIMD_MAX(max) for (int i = 0; i < length; i++) - max = nd4j::math::nd4j_max(max, outBuff[i * inEWS]); + max = nd4j::math::nd4j_max(max, inBuff[i * inEWS]); PRAGMA_OMP_SIMD_SUM(sum) for (int i = 0; i < length; i++) { @@ -204,7 +204,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra const int rank = input.rankOf(); if(input.isVector()) { - + if(rank == 1 || input.sizeAt(dimension) != 1) softMaxForVector_(input.getBuffer(), input.getShapeInfo(), output.buffer(), output.getShapeInfo()); else @@ -228,7 +228,7 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra T max = -DataTypeUtils::max(); T sum = 0; - + for(uint j = 0; j < tadLen; ++j) max = nd4j::math::nd4j_max(max, inBuff[j]); @@ -237,9 +237,9 @@ static void softmax_(nd4j::LaunchContext * context, const NDArray& input, NDArra outBuff[j] = temp; sum += temp; } - + for (uint j = 0; j < tadLen; ++j) - outBuff[j] /= sum; + outBuff[j] /= sum; } } else { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp index aa5832776..4bbad5146 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/convolutions.cpp @@ -18,220 +18,336 @@ // @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018 // -#include #include +#include #include #include #include #include namespace nd4j { -namespace ops { + namespace ops { + +#ifdef HAVE_MKLDNN + using namespace mkldnn; + +void ConvolutionUtils::getMKLDNNMemoryDescPool2d( + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, + int bS, int iC, int iH, int iW, int oC, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, + mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, + mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { + mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW }; + mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW }; + + pool_strides = { sH, sW }; + pool_kernel = { kH, kW }; + pool_padding = { pH, pW }; + pool_padding_r = { (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; + + algorithm = poolingMode == 0 ? pooling_max + : extraParam0 == 0 ? pooling_avg_exclude_padding + : pooling_avg_include_padding; + auto type = mkldnn::memory::data_type::f32; + auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc; + auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { + *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0]; + user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3]; + user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1]; + user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { + *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; + user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; + user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; + user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" + user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0]; + user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3]; + user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1]; + user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2]; + } +} + +void ConvolutionUtils::getMKLDNNMemoryDescPool3d( + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, + int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, + const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, + mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, + mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, + mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { + mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; + mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; + + pool_strides = { sD, sH, sW }; + pool_kernel = { kD, kH, kW }; + pool_padding = { pD, pH, pW }; + pool_padding_r = { (oD - 1) * sD - iD + kD - pD, + (oH - 1) * sH - iH + kH - pH, + (oW - 1) * sW - iW + kW - pW }; + + algorithm = poolingMode == 0 ? pooling_max + : extraParam0 == 0 ? pooling_avg_exclude_padding + : pooling_avg_include_padding; + auto type = mkldnn::memory::data_type::f32; + auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc; + auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { + *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); + user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0]; + user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4]; + user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1]; + user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2]; + user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { + *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); + user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; + user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; + user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; + user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; + user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { + *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); + *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); + user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" + user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0]; + user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4]; + user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1]; + user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2]; + user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3]; + } +} +#endif ////////////////////////////////////////////////////////////////////////// -// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] -template -static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { +// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW] + template + static void vol2col_(const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* colBuff = columns.bufferAsT(); - T* volBuff = const_cast(volume).bufferAsT(); + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; - T *col, *vol; - int volDep, volRow, volCol; + T* colBuff = columns.bufferAsT(); + T* volBuff = const_cast(volume).bufferAsT(); -if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) + T *col, *vol; + int volDep, volRow, volCol; - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2)) - for (int b = 0; b < bS; b++) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { - - volDep = (-pD + kDep * dD) + colD*sD; - volRow = (-pH + kRow * dH) + colH*sH; - volCol = (-pW + kCol * dW) + colW*sW; - - col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; - vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.); - else - *col = *vol; - } - } - } - } - } - } - } - } + if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) -else - - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol)) - for (int b = 0; b < bS; b++) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2)) + for (int b = 0; b < bS; b++) { for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD*sD; - volRow = (-pH + kRow * dH) + colH*sH; - volCol = (-pW + kCol * dW) + colW*sW; - - col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; - vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; - - if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) - *col = static_cast(0.); - else - *col = *vol; + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + + volDep = (-pD + kDep * dD) + colD*sD; + volRow = (-pH + kRow * dH) + colH*sH; + volCol = (-pW + kCol * dW) + colW*sW; + + col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; + vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; + + if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) + *col = static_cast(0.); + else + *col = *vol; + } + } + } + } + } + } + } + } + + else + + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol)) + for (int b = 0; b < bS; b++) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + + volDep = (-pD + kDep * dD) + colD*sD; + volRow = (-pH + kRow * dH) + colH*sH; + volCol = (-pW + kCol * dW) + colW*sW; + + col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; + vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; + + if (static_cast(volDep) >= static_cast(iD) || static_cast(volRow) >= static_cast(iH) || static_cast(volCol) >= static_cast(iW)) + *col = static_cast(0.); + else + *col = *vol; + } + } + } } } } } } - } } - } -} ////////////////////////////////////////////////////////////////////////// // [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW] -template -static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + template + static void col2vol_(const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - const int bS = volume.sizeAt(0); - const int iC = volume.sizeAt(1); - const int iD = volume.sizeAt(2); - const int iH = volume.sizeAt(3); - const int iW = volume.sizeAt(4); - const int kD = columns.sizeAt(2); - const int kH = columns.sizeAt(3); - const int kW = columns.sizeAt(4); - const int oD = columns.sizeAt(5); - const int oH = columns.sizeAt(6); - const int oW = columns.sizeAt(7); - const Nd4jLong colStride0 = columns.stridesOf()[0]; - const Nd4jLong colStride1 = columns.stridesOf()[1]; - const Nd4jLong colStride2 = columns.stridesOf()[2]; - const Nd4jLong colStride3 = columns.stridesOf()[3]; - const Nd4jLong colStride4 = columns.stridesOf()[4]; - const Nd4jLong colStride5 = columns.stridesOf()[5]; - const Nd4jLong colStride6 = columns.stridesOf()[6]; - const Nd4jLong colStride7 = columns.stridesOf()[7]; - const Nd4jLong volStride0 = volume.stridesOf()[0]; - const Nd4jLong volStride1 = volume.stridesOf()[1]; - const Nd4jLong volStride2 = volume.stridesOf()[2]; - const Nd4jLong volStride3 = volume.stridesOf()[3]; - const Nd4jLong volStride4 = volume.stridesOf()[4]; - - T* volBuff = volume.bufferAsT(); - T* colBuff = const_cast(columns).bufferAsT(); + const int bS = volume.sizeAt(0); + const int iC = volume.sizeAt(1); + const int iD = volume.sizeAt(2); + const int iH = volume.sizeAt(3); + const int iW = volume.sizeAt(4); + const int kD = columns.sizeAt(2); + const int kH = columns.sizeAt(3); + const int kW = columns.sizeAt(4); + const int oD = columns.sizeAt(5); + const int oH = columns.sizeAt(6); + const int oW = columns.sizeAt(7); + const Nd4jLong colStride0 = columns.stridesOf()[0]; + const Nd4jLong colStride1 = columns.stridesOf()[1]; + const Nd4jLong colStride2 = columns.stridesOf()[2]; + const Nd4jLong colStride3 = columns.stridesOf()[3]; + const Nd4jLong colStride4 = columns.stridesOf()[4]; + const Nd4jLong colStride5 = columns.stridesOf()[5]; + const Nd4jLong colStride6 = columns.stridesOf()[6]; + const Nd4jLong colStride7 = columns.stridesOf()[7]; + const Nd4jLong volStride0 = volume.stridesOf()[0]; + const Nd4jLong volStride1 = volume.stridesOf()[1]; + const Nd4jLong volStride2 = volume.stridesOf()[2]; + const Nd4jLong volStride3 = volume.stridesOf()[3]; + const Nd4jLong volStride4 = volume.stridesOf()[4]; - // initial zeroing of volume content - memset(volBuff, 0, volume.lengthOf() * sizeof(T)); + T* volBuff = volume.bufferAsT(); + T* colBuff = const_cast(columns).bufferAsT(); - T* col, *vol; - int volDep, volRow, volCol; + // initial zeroing of volume content + memset(volBuff, 0, volume.lengthOf() * sizeof(T)); -if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) + T* col, *vol; + int volDep, volRow, volCol; - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2)) - for (int b = 0; b < bS; b++) { - for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { + if (volume.ordering() == 'c' && columns.ordering() == 'c' && shape::strideDescendingCAscendingF(volume.getShapeInfo()) && shape::strideDescendingCAscendingF(columns.getShapeInfo())) - volDep = (-pD + kDep * dD) + colD*sD; - volRow = (-pH + kRow * dH) + colH*sH; - volCol = (-pW + kCol * dW) + colW*sW; - - col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; - vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) - *vol += *col; - } - } - } - } - } - } - } - } - -else - - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol)) - for (int b = 0; b < bS; b++) { - for (int colD = 0; colD < oD; ++colD) { - for (int colH = 0; colH < oH; ++colH) { - for (int colW = 0; colW < oW; ++colW) { + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(col, vol, volDep, volRow, volCol) collapse(2)) + for (int b = 0; b < bS; b++) { for (int c = 0; c < iC; ++c) { - for (int kDep = 0; kDep < kD; ++kDep) { - for (int kRow = 0; kRow < kH; ++kRow) { - for (int kCol = 0; kCol < kW; ++kCol) { - - volDep = (-pD + kDep * dD) + colD*sD; - volRow = (-pH + kRow * dH) + colH*sH; - volCol = (-pW + kCol * dW) + colW*sW; - - col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; - vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; - - if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) - *vol += *col; + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + + volDep = (-pD + kDep * dD) + colD*sD; + volRow = (-pH + kRow * dH) + colH*sH; + volCol = (-pW + kCol * dW) + colW*sW; + + col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; + vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; + + if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) + *vol += *col; + } + } + } + } + } + } + } + } + + else + + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(vol, col, volDep, volRow, volCol)) + for (int b = 0; b < bS; b++) { + for (int colD = 0; colD < oD; ++colD) { + for (int colH = 0; colH < oH; ++colH) { + for (int colW = 0; colW < oW; ++colW) { + for (int c = 0; c < iC; ++c) { + for (int kDep = 0; kDep < kD; ++kDep) { + for (int kRow = 0; kRow < kH; ++kRow) { + for (int kCol = 0; kCol < kW; ++kCol) { + + volDep = (-pD + kDep * dD) + colD*sD; + volRow = (-pH + kRow * dH) + colH*sH; + volCol = (-pW + kCol * dW) + colW*sW; + + col = colBuff + b*colStride0 + c*colStride1 + kDep*colStride2 + kRow*colStride3 + kCol*colStride4 + colD*colStride5 + colH*colStride6 + colW*colStride7; + vol = volBuff + b*volStride0 + c*volStride1 + volDep*volStride2 + volRow*volStride3 + volCol*volStride4; + + if (static_cast(volDep) < static_cast(iD) && static_cast(volRow) < static_cast(iH) && static_cast(volCol) < static_cast(iW)) + *vol += *col; + } + } + } } } } } } - } } - } -} #ifdef HAVE_MKLDNN -using namespace mkldnn; + using namespace mkldnn; void ConvolutionUtils::getMKLDNNMemoryDescConv2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, @@ -399,34 +515,34 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d( #endif ////////////////////////////////////////////////////////////////////////// -template -static void conv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + template + static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always - // bias [oC] - // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC] always + // bias [oC] + // output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // isSameMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + if(isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); #ifdef HAVE_MKLDNN - if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("conv2d")); @@ -502,75 +618,75 @@ static void conv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDA return; } #endif - nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); + nd4j_debug("MKL-DNN is not used for conv2d!\n", 0); - std::vector permutForOutput; - if(!isNCHW) - input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC - else - // permutForOutput = {0, indOoH, indOoH+1, indIOioC}; // [bS, oC, oH, oW] -> [bS, oH, oW, oC] - permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + std::vector permutForOutput; + if(!isNCHW) + input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] if NHWC + else + // permutForOutput = {0, indOoH, indOoH+1, indIOioC}; // [bS, oC, oH, oW] -> [bS, oH, oW, oC] + permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), &block); - NDArray* colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} - NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), &block); + NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext()); + NDArray* colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW} + NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext()); - //----- calculation of output -----// - nd4j::LaunchContext ctx; - helpers::im2col(ctx, *input, *colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, &block)); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] + //----- calculation of output -----// + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, *colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC] - //----- assign outTemp to output -----// - if(isNCHW) { - mmulResult.reshapei({bS, oH, oW, oC}); - mmulResult.permutei(permutForOutput); - } - output->assign(mmulResult); + //----- assign outTemp to output -----// + if(isNCHW) { + mmulResult.reshapei({bS, oH, oW, oC}); + mmulResult.permutei(permutForOutput); + } + output->assign(mmulResult); - //----- add biases if required -----// - if(bias) - // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - helpers::addBias(*output, *bias, isNCHW); + //----- add biases if required -----// + if(bias) + // output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + helpers::addBias(*output, *bias, isNCHW); - if(!isNCHW) - delete input; + if(!isNCHW) + delete input; - delete colP; -} + delete colP; + } ////////////////////////////////////////////////////////////////////////// -template -static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + template + static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, oC] always - // bias [oC] - // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, oC] always + // bias [oC] + // gradO [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon - // gradW [kH, kW, iC, oC] always - // gradB [oC] + // gradI [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + // gradW [kH, kW, iC, oC] always + // gradB [oC] - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // isSameMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW - int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; - int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + if(isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); #ifdef HAVE_MKLDNN - if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("conv2d_bp_weights")); @@ -714,294 +830,252 @@ static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const N return; } #endif - nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); + nd4j_debug("MKL-DNN is not used for conv2d_bp!\n", 0); - std::vector gradOaxesForDot; + std::vector gradOaxesForDot; - if(!isNCHW) { - input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradOaxesForDot = {0, 1, 2}; // bS, oH, oW - } - else - gradOaxesForDot = {0, 2, 3}; // bS, oH, oW + if(!isNCHW) { + input = input->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradI = gradI->permute({0, 3, 1, 2}); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + gradOaxesForDot = {0, 1, 2}; // bS, oH, oW + } + else + gradOaxesForDot = {0, 2, 3}; // bS, oH, oW - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - // ----- calculation of gradW ----- // - if(gradW) { - nd4j::LaunchContext * ctx = █ - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - nd4j::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] - } + // ----- calculation of gradW ----- // + if(gradW) { + auto ctx = block.launchContext(); + helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + nd4j::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC] + } - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); - gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); + gradO->reduceAlongDimension(reduce::Sum, gradBR, gradOaxesForDot); // sum over bS, oH, oW + if(gradBR != gradB) + delete gradBR; + } - //----- calculation of gradI -----// - nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - nd4j::LaunchContext * ctx = █ - helpers::col2im(*ctx, columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + //----- calculation of gradI -----// + nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW] - if(!isNCHW) { - delete input; - delete gradI; - } -} + helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] -////////////////////////////////////////////////////////////////////////// -template -static void depthwiseConv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weights [kH, kW, iC, mC] always - // bias [oC] = iC*mC - // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 0-NCHW, 1-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] - std::vector> modifOutput; - std::vector outReShape; - - if(!isNCHW) { - input = input->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - else { - outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray* outputReshaped = output->reshape(output->ordering(), outReShape); - - nd4j::LaunchContext * ctx = █ - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - MmulHelper::tensorDot(&columns, weights, outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] - - if(bias) - output->applyBroadcast(broadcast::Add, {indIOioC}, bias); - - if(!isNCHW) - delete input; - - delete outputReshaped; -} - -////////////////////////////////////////////////////////////////////////// -template -static void depthwiseConv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - - // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) - // weights [kH, kW, iC, mC] always - // bias [oC] = [iC*mC] - // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next - // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon - // gradW [kH, kW, iC, mC] always - // gradB [oC] - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 0-NHWC, 1-NCHW - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weights->sizeAt(indWmC); // channels multiplier - - std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] - std::vector> modifGradO1, modifGradO2; - std::vector gradOreShape; - - if(!isNCHW) { - input = input->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradI = gradI->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] - gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] - modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - else { - gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] - modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] - modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] - } - - if(isSameMode) // SAME - ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - - NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); - NDArray* gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); - - // ----- calculation of gradW and gradB ----- // - - nd4j::LaunchContext * ctx = █ - helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] - nd4j::MmulHelper::tensorDot(&columns, gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] - - // ----- calculation of gradB ----- // - if(gradB) { - NDArray* gradBR = gradB; - if(gradB->rankOf() == 2) - gradBR = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); - gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW - if(gradBR != gradB) - delete gradBR; - } - - //----- calculation of gradI -----// - nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] - helpers::col2im(*ctx, columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] - - if(!isNCHW) { - delete input; - delete gradI; - } - - delete gradOreshaped; -} - -////////////////////////////////////////////////////////////////////////// -template -static void sconv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - - // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - // weightsDepth [kH, kW, iC, mC] always - // weightsPoint [1, 1, iC*mC, oC] always - // bias [oC], oC = iC*mC if weightsPoint=nullptr - // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) - - // kH filter(kernel) height - // kW filter(kernel) width - // sH strides height - // sW strides width - // pH paddings height - // pW paddings width - // dH dilations height - // dW dilations width - // isSameMode 0-VALID, 1-SAME - // isNCHW 1-NCHW, 0-NHWC - - int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width - int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); - mC = weightsDepth->sizeAt(indWmC); // channels multiplier - - NDArray* outputDepth = output; - if(weightsPoint) // if pointwise convolution is expected - outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); - - // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // - ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); - - // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // - if (weightsPoint) { - ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW - delete outputDepth; - } -} - -////////////////////////////////////////////////////////////////////////// -template -static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - - std::vector indIn = {0,0, 0,0, 0,0, 0,0}; - std::vector indOut = {0,0, 0,0, 0,0, 0,0}; - const int dimIH = isNCHW ? 2 : 1; - const int j0 = 2*dimIH; - const int j1 = j0+1, j2 = j0+2, j3 = j0+3; - const int size0 = input.sizeAt(dimIH) * input.sizeAt(dimIH+1); - // const int size1 = factorH * factorW; - - int iT = input.sizeAt(dimIH); - int iH = input.sizeAt(dimIH + 1); - - PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) - for(int ih = 0; ih < iT; ++ih) { - for(int iw = 0; iw < iH; ++iw) { - indIn[j0] = ih; indIn[j1] = ih+1; - indIn[j2] = iw; indIn[j3] = iw+1; - - for(int fh = 0; fh < factorH; ++fh) { - for(int fw = 0; fw < factorW; ++fw) { - - indOut[j0] = ih * factorH + fh; indOut[j1] = indOut[j0] + 1; - indOut[j2] = iw * factorW + fw; indOut[j3] = indOut[j2] + 1; - auto i = input(indIn); - auto o = output(indOut); - o.assign(i); - } + if(!isNCHW) { + delete input; + delete gradI; } } - } -} ////////////////////////////////////////////////////////////////////////// -template -static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - std::vector indIn = {0,0, 0,0, 0,0, 0,0, 0,0}; - std::vector indOut = {0,0, 0,0, 0,0, 0,0, 0,0}; - const int dimID = isNCDHW ? 2 : 1; - const int j0 = 2*dimID; - const int j1 = j0+1, j2 = j0+2, j3 = j0+3, j4 = j0+4, j5 = j0+5;; - const int size0 = input.sizeAt(dimID) * input.sizeAt(dimID+1) * input.sizeAt(dimID+2); - // const int size1 = factorD * factorH * factorW; + template + static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - int l0 = input.sizeAt(dimID); - int l1 = input.sizeAt(dimID + 1); - int l2 = input.sizeAt(dimID + 2); + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weights [kH, kW, iC, mC] always + // bias [oC] = iC*mC + // output [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW) - PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) - for(int id = 0; id < l0; ++id) { - for(int ih = 0; ih < l1; ++ih) { - for(int iw = 0; iw < l2; ++iw) { - indIn[j0] = id; indIn[j1] = id+1; - indIn[j2] = ih; indIn[j3] = ih+1; - indIn[j4] = iw; indIn[j5] = iw+1; + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // isSameMode 0-VALID, 1-SAME + // isNCHW 0-NCHW, 1-NHWC - for(int fd = 0; fd < factorD; ++fd) { - for(int fh = 0; fh < factorH; ++fh) { - for(int fw = 0; fw < factorW; ++fw) { - indOut[j0] = id * factorD + fd; indOut[j1] = indOut[j0] + 1; - indOut[j2] = ih * factorH + fh; indOut[j3] = indOut[j2] + 1; - indOut[j4] = iw * factorW + fw; indOut[j5] = indOut[j4] + 1; - auto i = input(indIn); + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = {{1,0,4,5,2,3}, {iC,bS*oH*oW,kH*kW}}; // [bS,iC,kH,kW,oH,oW] -> [iC,bS,oH,oW,kH,kW] -> [iC,bS*oH*oW,kH*kW] + std::vector> modifOutput; + std::vector outReShape; + + if(!isNCHW) { + input = input->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + outReShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifOutput = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + else { + outReShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifOutput = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + } + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray* outputReshaped = output->reshape(output->ordering(), outReShape); + + helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + MmulHelper::tensorDot(&columns, weights, outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + + if(bias) + output->applyBroadcast(broadcast::Add, {indIOioC}, bias); + + if(!isNCHW) + delete input; + + delete outputReshaped; + } + +////////////////////////////////////////////////////////////////////////// + template + static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + + // input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) + // weights [kH, kW, iC, mC] always + // bias [oC] = [iC*mC] + // gradO [bS, oH, oW, oC] (NDHWC) or [bS, oC, oH, oW] (NCDHW), epsilon_next + // gradI [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW), epsilon + // gradW [kH, kW, iC, mC] always + // gradB [oC] + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // isSameMode 0-VALID, 1-SAME + // isNCHW 0-NHWC, 1-NCHW + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier(oC = iC*mC), output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weights->sizeAt(indWmC); // channels multiplier + + std::vector> modifColumns = {{1,2,3,0,4,5}, {iC, kH*kW, bS*oH*oW}}; // [bS,iC,kH,kW,oH,oW] -> [iC, kH*kW, bS*oH*oW] + std::vector> modifGradO1, modifGradO2; + std::vector gradOreShape; + + if(!isNCHW) { + input = input->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradI = gradI->permute({0, 3, 1, 2}); // [bS,iH,iW,iC] -> [bS,iC,iH,iW] + gradOreShape = {bS, oH, oW, iC, mC}; // [bS,oH,oW,iC*mC] -> [bS,oH,oW,iC,mC] + modifGradO1 = {{3,0,1,2,4},{iC, bS*oH*oW, mC}}; // [bS,oH,oW,iC,mC] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{3,0,1,2},{iC, mC, bS*oH*oW}}; // [bS,oH,oW,iC*mC] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + else { + gradOreShape = {bS, iC, mC, oH, oW}; // [bS,iC*mC,oH,oW] -> [bS,iC,mC,oH,oW] + modifGradO1 = {{1,0,3,4,2},{iC, bS*oH*oW, mC}}; // [bS,iC,mC,oH,oW] -> [iC,bS,oH,oW,mC] -> [iC,bS*oH*oW,mC] + modifGradO2 = {{1,0,2,3},{iC, mC, bS*oH*oW}}; // [bS,iC*mC,oH,oW] -> [iC*mC,bS,oH,oW] -> [iC,mC,bS*oH*oW] + } + + if(isSameMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + + NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext()); + NDArray* gradOreshaped = gradO->reshape(gradO->ordering(), gradOreShape); + + // ----- calculation of gradW and gradB ----- // + + helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW] + nd4j::MmulHelper::tensorDot(&columns, gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC] + + // ----- calculation of gradB ----- // + if(gradB) { + NDArray* gradBR = gradB; + if(gradB->rankOf() == 2) + gradBR = gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}); + gradO->reduceAlongDimension(reduce::Sum, gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW + if(gradBR != gradB) + delete gradBR; + } + + //----- calculation of gradI -----// + nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW] + helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + + if(!isNCHW) { + delete input; + delete gradI; + } + + delete gradOreshaped; + } + +////////////////////////////////////////////////////////////////////////// + template + static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + + // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + // weightsDepth [kH, kW, iC, mC] always + // weightsPoint [1, 1, iC*mC, oC] always + // bias [oC], oC = iC*mC if weightsPoint=nullptr + // output is [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + + // kH filter(kernel) height + // kW filter(kernel) width + // sH strides height + // sW strides width + // pH paddings height + // pW paddings width + // dH dilations height + // dW dilations width + // isSameMode 0-VALID, 1-SAME + // isNCHW 1-NCHW, 0-NHWC + + int bS, iC, iH, iW, mC, oC, oH, oW; // batch size, input channels, input height/width, channels multiplier, output channels, output height/width + int indIOioC, indIiH, indWmC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH); + mC = weightsDepth->sizeAt(indWmC); // channels multiplier + + NDArray* outputDepth = output; + if(weightsPoint) // if pointwise convolution is expected + outputDepth = new NDArray(output->ordering(), !isNCHW ? std::vector({bS, oH, oW, iC*mC}) : std::vector({bS, iC*mC, oH, oW}), input->dataType(), input->getContext()); + + // ----- perform depthwise convolution (if weightsPoint is absent then oC = iC*mC) ----- // + ConvolutionUtils::depthwiseConv2d(block, input, weightsDepth, weightsPoint ? nullptr : bias, outputDepth, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW); + + // ----- perform pointwise convolution (oH = iH, oW = iW) ----- // + if (weightsPoint) { + ConvolutionUtils::conv2d(block, outputDepth, weightsPoint, bias, output, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH, oW=iW + delete outputDepth; + } + } + +////////////////////////////////////////////////////////////////////////// + template + static void upsampling2d_(const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { + // input has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + // output has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + + std::vector indIn = {0,0, 0,0, 0,0, 0,0}; + std::vector indOut = {0,0, 0,0, 0,0, 0,0}; + const int dimIH = isNCHW ? 2 : 1; + const int j0 = 2*dimIH; + const int j1 = j0+1, j2 = j0+2, j3 = j0+3; + const int size0 = input.sizeAt(dimIH) * input.sizeAt(dimIH+1); + // const int size1 = factorH * factorW; + + int iT = input.sizeAt(dimIH); + int iH = input.sizeAt(dimIH + 1); + + PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) + for(int ih = 0; ih < iT; ++ih) { + for(int iw = 0; iw < iH; ++iw) { + indIn[j0] = ih; indIn[j1] = ih+1; + indIn[j2] = iw; indIn[j3] = iw+1; + + for(int fh = 0; fh < factorH; ++fh) { + for(int fw = 0; fw < factorW; ++fw) { + + indOut[j0] = ih * factorH + fh; indOut[j1] = indOut[j0] + 1; + indOut[j2] = iw * factorW + fw; indOut[j3] = indOut[j2] + 1; + auto i = input(indIn); auto o = output(indOut); o.assign(i); } @@ -1009,86 +1083,81 @@ static void upsampling3d_(const NDArray& input, NDArray& output, const int facto } } } - } -} ////////////////////////////////////////////////////////////////////////// -template -static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) - // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) - std::vector indIn = {0,0, 0,0, 0,0, 0,0}; - std::vector indOut = {0,0, 0,0, 0,0, 0,0}; - const int dimIH = isNCHW ? 2 : 1; - const int factorH = gradO.sizeAt(dimIH) / gradI.sizeAt(dimIH); - const int factorW = gradO.sizeAt(dimIH+1) / gradI.sizeAt(dimIH+1); - const int j0 = 2*dimIH; - const int j1 = j0+1, j2 = j0+2, j3 = j0+3; - const int size0 = gradI.sizeAt(dimIH) * gradI.sizeAt(dimIH+1); + template + static void upsampling3d_(const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + std::vector indIn = {0,0, 0,0, 0,0, 0,0, 0,0}; + std::vector indOut = {0,0, 0,0, 0,0, 0,0, 0,0}; + const int dimID = isNCDHW ? 2 : 1; + const int j0 = 2*dimID; + const int j1 = j0+1, j2 = j0+2, j3 = j0+3, j4 = j0+4, j5 = j0+5;; + const int size0 = input.sizeAt(dimID) * input.sizeAt(dimID+1) * input.sizeAt(dimID+2); + // const int size1 = factorD * factorH * factorW; - int l0 = gradI.sizeAt(dimIH); - int l1 = gradI.sizeAt(dimIH + 1); + int l0 = input.sizeAt(dimID); + int l1 = input.sizeAt(dimID + 1); + int l2 = input.sizeAt(dimID + 2); - PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) - for(int ih = 0; ih < l0; ++ih) { - for(int iw = 0; iw < l1; ++iw) { - indIn[j0] = ih; indIn[j1] = ih+1; - indIn[j2] = iw; indIn[j3] = iw+1; - NDArray subGradI = gradI(indIn); + PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) + for(int id = 0; id < l0; ++id) { + for(int ih = 0; ih < l1; ++ih) { + for(int iw = 0; iw < l2; ++iw) { + indIn[j0] = id; indIn[j1] = id+1; + indIn[j2] = ih; indIn[j3] = ih+1; + indIn[j4] = iw; indIn[j5] = iw+1; - for(int fh = 0; fh < factorH; ++fh) { - for(int fw = 0; fw < factorW; ++fw) { - indOut[j0] = ih * factorH + fh; indOut[j1] = indOut[j0] + 1; - indOut[j2] = iw * factorW + fw; indOut[j3] = indOut[j2] + 1; - auto o = gradO(indOut); - if(!fh && !fw) { - subGradI.assign(o); + for(int fd = 0; fd < factorD; ++fd) { + for(int fh = 0; fh < factorH; ++fh) { + for(int fw = 0; fw < factorW; ++fw) { + indOut[j0] = id * factorD + fd; indOut[j1] = indOut[j0] + 1; + indOut[j2] = ih * factorH + fh; indOut[j3] = indOut[j2] + 1; + indOut[j4] = iw * factorW + fw; indOut[j5] = indOut[j4] + 1; + auto i = input(indIn); + auto o = output(indOut); + o.assign(i); + } + } + } } - else - subGradI += o; } } } - } -} ////////////////////////////////////////////////////////////////////////// -template -static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { - // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) - // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) - std::vector indIn = {0,0, 0,0, 0,0, 0,0, 0,0}; - std::vector indOut = {0,0, 0,0, 0,0, 0,0, 0,0}; - const int dimID = isNCDHW ? 2 : 1; - const int factorD = gradO.sizeAt(dimID) / gradI.sizeAt(dimID); - const int factorH = gradO.sizeAt(dimID+1) / gradI.sizeAt(dimID+1); - const int factorW = gradO.sizeAt(dimID+2) / gradI.sizeAt(dimID+2); - const int j0 = 2*dimID; - const int j1 = j0+1, j2 = j0+2, j3 = j0+3, j4 = j0+4, j5 = j0+5;; - const int size0 = gradI.sizeAt(dimID) * gradI.sizeAt(dimID+1) * gradI.sizeAt(dimID+2); + template + static void upsampling2dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + // gradO has shape [bS, iC, factorH*iH, factorW*iW ] (NCHW) or [bS, factorH*iH, factorW*iW, iC] (NHWC) + // gradI has shape [bS, iC, iH, iW] (NCHW) or [bS, iH, iW, iC] (NHWC) + std::vector indIn = {0,0, 0,0, 0,0, 0,0}; + std::vector indOut = {0,0, 0,0, 0,0, 0,0}; + const int dimIH = isNCHW ? 2 : 1; + const int factorH = gradO.sizeAt(dimIH) / gradI.sizeAt(dimIH); + const int factorW = gradO.sizeAt(dimIH+1) / gradI.sizeAt(dimIH+1); + const int j0 = 2*dimIH; + const int j1 = j0+1, j2 = j0+2, j3 = j0+3; + const int size0 = gradI.sizeAt(dimIH) * gradI.sizeAt(dimIH+1); - int l0 = gradI.sizeAt(dimID); - int l1 = gradI.sizeAt(dimID + 1); - int l2 = gradI.sizeAt(dimID + 2); + int l0 = gradI.sizeAt(dimIH); + int l1 = gradI.sizeAt(dimIH + 1); - PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(3) firstprivate(indOut, indIn)) - for(int id = 0; id < l0; ++id) { - for(int ih = 0; ih < l1; ++ih) { - for(int iw = 0; iw < l2; ++iw) { - indIn[j0] = id; indIn[j1] = id+1; - indIn[j2] = ih; indIn[j3] = ih+1; - indIn[j4] = iw; indIn[j5] = iw+1; - NDArray subGradI = gradI(indIn); + PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(2) firstprivate(indIn, indOut)) + for(int ih = 0; ih < l0; ++ih) { + for(int iw = 0; iw < l1; ++iw) { + indIn[j0] = ih; indIn[j1] = ih+1; + indIn[j2] = iw; indIn[j3] = iw+1; + NDArray subGradI = gradI(indIn); - for(int fd = 0; fd < factorD; ++fd) { - for(int fh = 0; fh < factorH; ++fh) { - for(int fw = 0; fw < factorW; ++fw) { - indOut[j0] = id * factorD + fd; indOut[j1] = indOut[j0] + 1; - indOut[j2] = ih * factorH + fh; indOut[j3] = indOut[j2] + 1; - indOut[j4] = iw * factorW + fw; indOut[j5] = indOut[j4] + 1; + for(int fh = 0; fh < factorH; ++fh) { + for(int fw = 0; fw < factorW; ++fw) { + indOut[j0] = ih * factorH + fh; indOut[j1] = indOut[j0] + 1; + indOut[j2] = iw * factorW + fw; indOut[j3] = indOut[j2] + 1; auto o = gradO(indOut); - if(!fd && !fh && !fw) + if(!fh && !fw) { subGradI.assign(o); + } else subGradI += o; } @@ -1096,147 +1165,76 @@ static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isN } } } - } -} - - -#ifdef HAVE_MKLDNN -using namespace mkldnn; - -void ConvolutionUtils::getMKLDNNMemoryDescPool2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, - int bS, int iC, int iH, int iW, int oC, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { - mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW }; - mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW }; - - pool_strides = { sH, sW }; - pool_kernel = { kH, kW }; - pool_padding = { pH, pW }; - pool_padding_r = { (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; - - algorithm = poolingMode == 0 ? pooling_max - : extraParam0 == 0 ? pooling_avg_exclude_padding - : pooling_avg_include_padding; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc; - auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any" - - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0]; - user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3]; - user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1]; - user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0]; - user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3]; - user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1]; - user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc" - user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0]; - user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3]; - user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1]; - user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2]; - } -} - -void ConvolutionUtils::getMKLDNNMemoryDescPool3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW, - int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, - const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm, - mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md, - mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md, - mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) { - mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW }; - mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW }; - - pool_strides = { sD, sH, sW }; - pool_kernel = { kD, kH, kW }; - pool_padding = { pD, pH, pW }; - pool_padding_r = { (oD - 1) * sD - iD + kD - pD, - (oH - 1) * sH - iH + kH - pH, - (oW - 1) * sW - iW + kW - pW }; - - algorithm = poolingMode == 0 ? pooling_max - : extraParam0 == 0 ? pooling_avg_exclude_padding - : pooling_avg_include_padding; - auto type = mkldnn::memory::data_type::f32; - auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc; - auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any" - - if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) { - *pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0]; - user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4]; - user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1]; - user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2]; - user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) { - *pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format); - user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0]; - user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4]; - user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1]; - user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2]; - user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3]; - } - - if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) { - *pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format); - *user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format); - user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc" - user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0]; - user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4]; - user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1]; - user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2]; - user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3]; - } -} -#endif ////////////////////////////////////////////////////////////////////////// -template -static void pooling2d_(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iH, iW] - // output is [bS, iC, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); + template + static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isNCDHW) { + // input has shape [bS, iC, iD, iH, iW] (NCDHW) or [bS, iD, iH, iW, iC] (NDHWC) + // output has shape [bS, iC, factorD*iD, factorH*iH, factorW*iW ] (NCDHW) or [bS, factorD*iD, factorH*iH, factorW*iW, iC] (NDHWC) + std::vector indIn = {0,0, 0,0, 0,0, 0,0, 0,0}; + std::vector indOut = {0,0, 0,0, 0,0, 0,0, 0,0}; + const int dimID = isNCDHW ? 2 : 1; + const int factorD = gradO.sizeAt(dimID) / gradI.sizeAt(dimID); + const int factorH = gradO.sizeAt(dimID+1) / gradI.sizeAt(dimID+1); + const int factorW = gradO.sizeAt(dimID+2) / gradI.sizeAt(dimID+2); + const int j0 = 2*dimID; + const int j1 = j0+1, j2 = j0+2, j3 = j0+3, j4 = j0+4, j5 = j0+5;; + const int size0 = gradI.sizeAt(dimID) * gradI.sizeAt(dimID+1) * gradI.sizeAt(dimID+2); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); + int l0 = gradI.sizeAt(dimID); + int l1 = gradI.sizeAt(dimID + 1); + int l2 = gradI.sizeAt(dimID + 2); - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iH = input.sizeAt(2); - const int iW = input.sizeAt(3); - const int oC = output.sizeAt(1); - const int oH = output.sizeAt(2); - const int oW = output.sizeAt(3); + PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(3) firstprivate(indOut, indIn)) + for(int id = 0; id < l0; ++id) { + for(int ih = 0; ih < l1; ++ih) { + for(int iw = 0; iw < l2; ++iw) { + indIn[j0] = id; indIn[j1] = id+1; + indIn[j2] = ih; indIn[j3] = ih+1; + indIn[j4] = iw; indIn[j5] = iw+1; + NDArray subGradI = gradI(indIn); + + for(int fd = 0; fd < factorD; ++fd) { + for(int fh = 0; fh < factorH; ++fh) { + for(int fw = 0; fw < factorW; ++fw) { + indOut[j0] = id * factorD + fd; indOut[j1] = indOut[j0] + 1; + indOut[j2] = ih * factorH + fh; indOut[j3] = indOut[j2] + 1; + indOut[j4] = iw * factorW + fw; indOut[j5] = indOut[j4] + 1; + auto o = gradO(indOut); + if(!fd && !fh && !fw) + subGradI.assign(o); + else + subGradI += o; + } + } + } + } + } + } + } + + +////////////////////////////////////////////////////////////////////////// + template + static void pooling2d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input is [bS, iC, iH, iW] + // output is [bS, iC, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iH = input.sizeAt(2); + const int iW = input.sizeAt(3); + const int oC = output.sizeAt(1); + const int oH = output.sizeAt(2); + const int oW = output.sizeAt(3); #ifdef HAVE_MKLDNN - if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("pooling2d")); @@ -1291,191 +1289,191 @@ static void pooling2d_(nd4j::LaunchContext & block, const NDArray& input, NDArra return; } #endif - nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); + nd4j_debug("MKL-DNN is not used for pooling2d!\n", 0); - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const int kProd = kH*kW; + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; - Nd4jLong hstart, wstart, hend, wend; - T *pIn; + const Nd4jLong iStep2 = dH*iStride2; + const Nd4jLong iStep3 = dW*iStride3; + const int kProd = kH*kW; - if(poolingMode == 0) { // max - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; + Nd4jLong hstart, wstart, hend, wend; + T *pIn; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + if(poolingMode == 0) { // max + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; + pIn = in + b * iStride0 + c * iStride1; - T max = -DataTypeUtils::max(); + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T val = pIn[kh + kw]; - if (val > max) - max = val; + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T max = -DataTypeUtils::max(); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T val = pIn[kh + kw]; + if (val > max) + max = val; + } + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; } - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = max; + } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; +/*************************************************************************/ + else if(poolingMode == 1) { // avg + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - T sum = static_cast(0.f); + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += pIn[kh + kw]; + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += pIn[kh + kw]; - auto oi = b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3; + auto oi = b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3; - if (extraParam0 == 0) { //Exclude padding - int _a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1); - int _b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1); + if (extraParam0 == 0) { //Exclude padding + int _a = (hend-hstart)/iStep2 + ((hend-hstart) % iStep2 == 0 ? 0 : 1); + int _b = (wend-wstart)/iStep3 + ((wend-wstart) % iStep3 == 0 ? 0 : 1); - sum /= _a * _b; //Accounts for dilation - } else if (extraParam0 == 1) //Include padding - sum /= kProd; - - out[oi] = sum; + sum /= _a * _b; //Accounts for dilation + } else if (extraParam0 == 1) //Include padding + sum /= kProd; + + out[oi] = sum; + } + } } } } +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { + + pIn = in + b * iStride0 + c * iStride1; + + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; + + T sum = static_cast(0.f); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + sum = nd4j::math::nd4j_pow(sum, static_cast((T)1.f) / extraParam0); + + out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; + } + } + } + } + } + else { + nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } } - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, hstart, wstart, hend, wend) collapse(2)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; - - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; - - T sum = static_cast(0.f); - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - sum = nd4j::math::nd4j_pow(sum, static_cast((T)1.f) / extraParam0); - - out[b * oStride0 + c * oStride1 + oh * oStride2 + ow * oStride3] = sum; - } - } - } - } - } - else { - nd4j_printf("ConvolutionUtils::pooling2d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } -} ////////////////////////////////////////////////////////////////////////// -template -static void pooling3d_(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input is [bS, iC, iD, iH, iW] - // output is [bS, iC, oD, oH, oW] - T* out = output.bufferAsT(); - T* in = const_cast(input).bufferAsT(); + template + static void pooling3d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input is [bS, iC, iD, iH, iW] + // output is [bS, iC, oD, oH, oW] + T* out = output.bufferAsT(); + T* in = const_cast(input).bufferAsT(); - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); + const int kDEff = kD + (kD-1)*(dD-1); + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); - const int bS = input.sizeAt(0); - const int iC = input.sizeAt(1); - const int iD = input.sizeAt(2); - const int iH = input.sizeAt(3); - const int iW = input.sizeAt(4); - const int oC = output.sizeAt(1); - const int oD = output.sizeAt(2); - const int oH = output.sizeAt(3); - const int oW = output.sizeAt(4); + const int bS = input.sizeAt(0); + const int iC = input.sizeAt(1); + const int iD = input.sizeAt(2); + const int iH = input.sizeAt(3); + const int iW = input.sizeAt(4); + const int oC = output.sizeAt(1); + const int oD = output.sizeAt(2); + const int oH = output.sizeAt(3); + const int oW = output.sizeAt(4); #ifdef HAVE_MKLDNN - if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("pooling3d")); @@ -1530,239 +1528,239 @@ static void pooling3d_(nd4j::LaunchContext & block, const NDArray& input, NDArra return; } #endif - nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); + nd4j_debug("MKL-DNN is not used for pooling3d!\n", 0); - const Nd4jLong iStride0 = input.stridesOf()[0]; - const Nd4jLong iStride1 = input.stridesOf()[1]; - const Nd4jLong iStride2 = input.stridesOf()[2]; - const Nd4jLong iStride3 = input.stridesOf()[3]; - const Nd4jLong iStride4 = input.stridesOf()[4]; - const Nd4jLong oStride0 = output.stridesOf()[0]; - const Nd4jLong oStride1 = output.stridesOf()[1]; - const Nd4jLong oStride2 = output.stridesOf()[2]; - const Nd4jLong oStride3 = output.stridesOf()[3]; - const Nd4jLong oStride4 = output.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const int kProd = kD*kH*kW; - const T iStep2Inv = 1./iStep2; - const T iStep3Inv = 1./iStep3; - const T iStep4Inv = 1./iStep4; + const Nd4jLong iStride0 = input.stridesOf()[0]; + const Nd4jLong iStride1 = input.stridesOf()[1]; + const Nd4jLong iStride2 = input.stridesOf()[2]; + const Nd4jLong iStride3 = input.stridesOf()[3]; + const Nd4jLong iStride4 = input.stridesOf()[4]; + const Nd4jLong oStride0 = output.stridesOf()[0]; + const Nd4jLong oStride1 = output.stridesOf()[1]; + const Nd4jLong oStride2 = output.stridesOf()[2]; + const Nd4jLong oStride3 = output.stridesOf()[3]; + const Nd4jLong oStride4 = output.stridesOf()[4]; + const Nd4jLong iStep2 = dD*iStride2; + const Nd4jLong iStep3 = dH*iStride3; + const Nd4jLong iStep4 = dW*iStride4; + const int kProd = kD*kH*kW; + const T iStep2Inv = 1./iStep2; + const T iStep3Inv = 1./iStep3; + const T iStep4Inv = 1./iStep4; - Nd4jLong dstart, hstart, wstart, dend, hend, wend; - T sum, *pIn; + Nd4jLong dstart, hstart, wstart, dend, hend, wend; + T sum, *pIn; - if(poolingMode == 0) { // max - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; + if(poolingMode == 0) { // max + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - sum = -DataTypeUtils::max(); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T val = pIn[kd + kh + kw]; - if (val > sum) - sum = val; - } - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = -DataTypeUtils::max(); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T val = pIn[kd + kh + kw]; + if (val > sum) + sum = val; + } + out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + } + } } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; +/*************************************************************************/ + else if(poolingMode == 1) { // avg + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += pIn[kd + kh + kw]; - - if ((int) extraParam0 == 0) //Exclude padding - sum /= static_cast(nd4j::math::nd4j_ceil(static_cast(dend-dstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep3))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep4))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - sum /= kProd; - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += pIn[kd + kh + kw]; + + if ((int) extraParam0 == 0) //Exclude padding + sum /= static_cast(nd4j::math::nd4j_ceil(static_cast(dend-dstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep3))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep4))); //Accounts for dilation + else if ((int) extraParam0 == 1) //Include padding + sum /= kProd; + + out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + } + } } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, sum, dstart, hstart, wstart, dend, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - sum = static_cast(0.); - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - - sum = nd4j::math::nd4j_pow(sum, (T) 1.f / extraParam0); - - out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; + + sum = static_cast(0.); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); + + sum = nd4j::math::nd4j_pow(sum, (T) 1.f / extraParam0); + + out[b * oStride0 + c * oStride1 + od * oStride2 + oh * oStride3 + ow * oStride4] = sum; + } + } } } } } + else { + nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } } - } - else { - nd4j_printf("ConvolutionUtils::pooling3d: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } -} ////////////////////////////////////////////////////////////////////////// -template -static void pooling2dBP_(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iH, iW] - // gradI [bS, iC, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oH, oW] + template + static void pooling2dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input [bS, iC, iH, iW] + // gradI [bS, iC, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oH, oW] - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); - // initial zeroing of gradI - const Nd4jLong gradIEWS = gradI.ews(); - const Nd4jLong gradILen = gradI.lengthOf(); - if(gradIEWS == 1) - memset(gI, 0, gradILen * sizeof(T)); - else if (gradIEWS > 1) { - for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) - gI[i] = static_cast(0.f); - } - else { - PRAGMA_OMP_PARALLEL_FOR_SIMD - for (Nd4jLong i = 0; i < gradILen; i++) - gI[shape::getIndexOffset(i, gradI.getShapeInfo(), gradILen)] = static_cast(0.f); - } + // initial zeroing of gradI + const Nd4jLong gradIEWS = gradI.ews(); + const Nd4jLong gradILen = gradI.lengthOf(); + if(gradIEWS == 1) + memset(gI, 0, gradILen * sizeof(T)); + else if (gradIEWS > 1) { + for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) + gI[i] = static_cast(0.f); + } + else { + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (Nd4jLong i = 0; i < gradILen; i++) + gI[shape::getIndexOffset(i, gradI.getShapeInfo(), gradILen)] = static_cast(0.f); + } - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iH = gradI.sizeAt(2); - const int iW = gradI.sizeAt(3); - const int oC = gradO.sizeAt(1); - const int oH = gradO.sizeAt(2); - const int oW = gradO.sizeAt(3); + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iH = gradI.sizeAt(2); + const int iW = gradI.sizeAt(3); + const int oC = gradO.sizeAt(1); + const int oH = gradO.sizeAt(2); + const int oW = gradO.sizeAt(3); #ifdef HAVE_MKLDNN - if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("pooling2d_bp")); @@ -1846,217 +1844,217 @@ static void pooling2dBP_(nd4j::LaunchContext & block, const NDArray& input, cons return; } #endif - nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); + nd4j_debug("MKL-DNN is not used for pooling2d_bp!\n", 0); - const Nd4jLong iStride0 = gradI.stridesOf()[0]; - const Nd4jLong iStride1 = gradI.stridesOf()[1]; - const Nd4jLong iStride2 = gradI.stridesOf()[2]; - const Nd4jLong iStride3 = gradI.stridesOf()[3]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong iStep2 = dH*iStride2; - const Nd4jLong iStep3 = dW*iStride3; - const int kProd = kH*kW; - const T iStep2Inv = 1./iStep2; - const T iStep3Inv = 1./iStep3; + const Nd4jLong iStride0 = gradI.stridesOf()[0]; + const Nd4jLong iStride1 = gradI.stridesOf()[1]; + const Nd4jLong iStride2 = gradI.stridesOf()[2]; + const Nd4jLong iStride3 = gradI.stridesOf()[3]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong iStep2 = dH*iStride2; + const Nd4jLong iStep3 = dW*iStride3; + const int kProd = kH*kW; + const T iStep2Inv = 1./iStep2; + const T iStep3Inv = 1./iStep3; - Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; - T sum, valO, *pIn, *pgI; + Nd4jLong hstart, wstart,hend, wend, maxKH, maxKW; + T sum, valO, *pIn, *pgI; - if(poolingMode == 0) { // max - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, hstart, wstart, hend, wend, maxKH, maxKW)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; + if(poolingMode == 0) { // max + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, hstart, wstart, hend, wend, maxKH, maxKW)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - sum = -DataTypeUtils::max(); - valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; - // we set these to default values - maxKH = hstart; - maxKW = wstart; + sum = -DataTypeUtils::max(); + valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { - T valIn = pIn[kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKH = kh; - maxKW = kw; - } + // we set these to default values + maxKH = hstart; + maxKW = wstart; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) { + T valIn = pIn[kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKH = kh; + maxKW = kw; + } + } + gI[pIn - in + maxKH + maxKW] += valO; } - gI[pIn - in + maxKH + maxKW] += valO; + } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, hstart, wstart, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * iStride0 + c * iStride1; +/*************************************************************************/ + else if(poolingMode == 1) { // avg + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, hstart, wstart, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; + pgI = gI + b * iStride0 + c * iStride1; - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; - - if ((int) extraParam0 == 0) //Exclude padding - valO /= static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep3))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - valO /= kProd; + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO; + valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; + + if ((int) extraParam0 == 0) //Exclude padding + valO /= static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep3))); //Accounts for dilation + else if ((int) extraParam0 == 1) //Include padding + valO /= kProd; + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + pgI[kh + kw] += valO; + } + } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, pgI, sum, hstart, wstart, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = gI + (pIn - in); +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, pgI, sum, hstart, wstart, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - hstart = oh * sH - pH; - wstart = ow * sW - pW; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; + pgI = gI + (pIn - in); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); - if(wstart < 0) - wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); + hstart = oh * sH - pH; + wstart = ow * sW - pW; + hend = hstart + kHEff; + wend = wstart + kWEff; + + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); // (Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-hstart) / static_cast(dH)); + if(wstart < 0) + wstart += dW * ((-wstart + dW -1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(-wstart) / static_cast(dW)); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(hend-iH) / static_cast(dH)); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); //(Nd4jLong)nd4j::math::nd4j_ceil(static_cast(wend-iW) / static_cast(dW)); - hstart *= iStride2; - hend *= iStride2; - wstart *= iStride3; - wend *= iStride3; + hstart *= iStride2; + hend *= iStride2; + wstart *= iStride3; + wend *= iStride3; - sum = static_cast(0.f); - valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; - - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0); - - valO *= nd4j::math::nd4j_pow(sum, ((T)1. - extraParam0) / extraParam0); + sum = static_cast(0.f); + valO = gO[b*oStride0 + c*oStride1 + oh*oStride2 + ow*oStride3]; - for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) - pgI[kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f); + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0); + + valO *= nd4j::math::nd4j_pow(sum, ((T)1. - extraParam0) / extraParam0); + + for (Nd4jLong kh = hstart; kh < hend; kh += iStep2) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep3) + pgI[kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kh + kw]), extraParam0 - 1.f); + } + } } } } + else { + nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } } - } - else { - nd4j_printf("ConvolutionUtils::pooling2dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } -} ////////////////////////////////////////////////////////////////////////// -template -static void pooling3dBP_(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - // input [bS, iC, iD, iH, iW] - // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function - // gradO [bS, iC, oD, oH, oW] - - T* in = const_cast(input).bufferAsT(); - T* gO = const_cast(gradO).bufferAsT(); - T* gI = gradI.bufferAsT(); + template + static void pooling3dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + // input [bS, iC, iD, iH, iW] + // gradI [bS, iC, iD, iH, iW] -> gradI is output in this function + // gradO [bS, iC, oD, oH, oW] - // initial zeroing of gradI - const Nd4jLong gradIEWS = gradI.ews(); - const Nd4jLong gradILen = gradI.lengthOf(); - if(gradIEWS == 1) { - memset(gI, 0, gradILen * sizeof(T)); - } - else if (gradIEWS > 1) { - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) - gI[i] = static_cast(0.f); - } - else { - PRAGMA_OMP_PARALLEL_FOR - for (Nd4jLong i = 0; i < gradILen; i++) - gI[shape::getIndexOffset(i, gradI.getShapeInfo(), gradILen)] = static_cast(0.f); - } + T* in = const_cast(input).bufferAsT(); + T* gO = const_cast(gradO).bufferAsT(); + T* gI = gradI.bufferAsT(); - const int kDEff = kD + (kD-1)*(dD-1); - const int kHEff = kH + (kH-1)*(dH-1); - const int kWEff = kW + (kW-1)*(dW-1); + // initial zeroing of gradI + const Nd4jLong gradIEWS = gradI.ews(); + const Nd4jLong gradILen = gradI.lengthOf(); + if(gradIEWS == 1) { + memset(gI, 0, gradILen * sizeof(T)); + } + else if (gradIEWS > 1) { + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < gradILen * gradIEWS; i += gradIEWS) + gI[i] = static_cast(0.f); + } + else { + PRAGMA_OMP_PARALLEL_FOR + for (Nd4jLong i = 0; i < gradILen; i++) + gI[shape::getIndexOffset(i, gradI.getShapeInfo(), gradILen)] = static_cast(0.f); + } - const int bS = gradI.sizeAt(0); - const int iC = gradI.sizeAt(1); - const int iD = gradI.sizeAt(2); - const int iH = gradI.sizeAt(3); - const int iW = gradI.sizeAt(4); - const int oC = gradO.sizeAt(1); - const int oD = gradO.sizeAt(2); - const int oH = gradO.sizeAt(3); - const int oW = gradO.sizeAt(4); + const int kDEff = kD + (kD-1)*(dD-1); + const int kHEff = kH + (kH-1)*(dH-1); + const int kWEff = kW + (kW-1)*(dW-1); + + const int bS = gradI.sizeAt(0); + const int iC = gradI.sizeAt(1); + const int iD = gradI.sizeAt(2); + const int iH = gradI.sizeAt(3); + const int iW = gradI.sizeAt(4); + const int oC = gradO.sizeAt(1); + const int oD = gradO.sizeAt(2); + const int oH = gradO.sizeAt(3); + const int oW = gradO.sizeAt(4); #ifdef HAVE_MKLDNN - if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { + if (poolingMode < 2 && block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported()) { std::vector& streams = block.getMKLDNNStreams(); if (streams.empty()) { streams.push_back(MKLDNNStream("pooling3d_bp")); @@ -2143,280 +2141,280 @@ static void pooling3dBP_(nd4j::LaunchContext & block, const NDArray& input, cons return; } #endif - nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); + nd4j_debug("MKL-DNN is not used for pooling3d_bp!\n", 0); - const Nd4jLong iStride0 = gradI.stridesOf()[0]; - const Nd4jLong iStride1 = gradI.stridesOf()[1]; - const Nd4jLong iStride2 = gradI.stridesOf()[2]; - const Nd4jLong iStride3 = gradI.stridesOf()[3]; - const Nd4jLong iStride4 = gradI.stridesOf()[4]; - const Nd4jLong oStride0 = gradO.stridesOf()[0]; - const Nd4jLong oStride1 = gradO.stridesOf()[1]; - const Nd4jLong oStride2 = gradO.stridesOf()[2]; - const Nd4jLong oStride3 = gradO.stridesOf()[3]; - const Nd4jLong oStride4 = gradO.stridesOf()[4]; - const Nd4jLong iStep2 = dD*iStride2; - const Nd4jLong iStep3 = dH*iStride3; - const Nd4jLong iStep4 = dW*iStride4; - const int kProd = kD*kH*kW; - const T iStep2Inv = 1./iStep2; - const T iStep3Inv = 1./iStep3; - const T iStep4Inv = 1./iStep4; + const Nd4jLong iStride0 = gradI.stridesOf()[0]; + const Nd4jLong iStride1 = gradI.stridesOf()[1]; + const Nd4jLong iStride2 = gradI.stridesOf()[2]; + const Nd4jLong iStride3 = gradI.stridesOf()[3]; + const Nd4jLong iStride4 = gradI.stridesOf()[4]; + const Nd4jLong oStride0 = gradO.stridesOf()[0]; + const Nd4jLong oStride1 = gradO.stridesOf()[1]; + const Nd4jLong oStride2 = gradO.stridesOf()[2]; + const Nd4jLong oStride3 = gradO.stridesOf()[3]; + const Nd4jLong oStride4 = gradO.stridesOf()[4]; + const Nd4jLong iStep2 = dD*iStride2; + const Nd4jLong iStep3 = dH*iStride3; + const Nd4jLong iStep4 = dW*iStride4; + const int kProd = kD*kH*kW; + const T iStep2Inv = 1./iStep2; + const T iStep3Inv = 1./iStep3; + const T iStep4Inv = 1./iStep4; - Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; - T sum, valO, *pIn, *pgI; + Nd4jLong dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW; + T sum, valO, *pIn, *pgI; - if(poolingMode == 0) { // max - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; + if(poolingMode == 0) { // max + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, valO, sum, dstart, hstart, wstart, dend, hend, wend, maxKD, maxKH, maxKW)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - maxKD = dstart; - maxKH = hstart; - maxKW = wstart; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; - sum = -DataTypeUtils::max(); - valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; - - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { - T valIn = pIn[kd + kh + kw]; - if (valIn > sum) { - sum = valIn; - maxKD = kd; - maxKH = kh; - maxKW = kw; - } - } - gI[pIn - in + maxKD + maxKH + maxKW] += valO; + maxKD = dstart; + maxKH = hstart; + maxKW = wstart; + + sum = -DataTypeUtils::max(); + valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) { + T valIn = pIn[kd + kh + kw]; + if (valIn > sum) { + sum = valIn; + maxKD = kd; + maxKH = kh; + maxKW = kw; + } + } + gI[pIn - in + maxKD + maxKH + maxKW] += valO; + } + } } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 1) { // avg - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, dstart, hstart, wstart, dend, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pgI = gI + b * iStride0 + c * iStride1; +/*************************************************************************/ + else if(poolingMode == 1) { // avg + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pgI, valO, dstart, hstart, wstart, dend, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pgI = gI + b * iStride0 + c * iStride1; - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; - - if ((int) extraParam0 == 0) //Exclude padding - valO /= static_cast(nd4j::math::nd4j_ceil(static_cast(dend-dstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep3))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep4))); //Accounts for dilation - else if ((int) extraParam0 == 1) //Include padding - valO /= kProd; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO; + valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; + + if ((int) extraParam0 == 0) //Exclude padding + valO /= static_cast(nd4j::math::nd4j_ceil(static_cast(dend-dstart) / static_cast(iStep2))) * static_cast(nd4j::math::nd4j_ceil(static_cast(hend-hstart) / static_cast(iStep3))) * static_cast(nd4j::math::nd4j_ceil(static_cast(wend-wstart) / static_cast(iStep4))); //Accounts for dilation + else if ((int) extraParam0 == 1) //Include padding + valO /= kProd; + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + pgI[kd + kh + kw] += valO; + } + } } } } } - } - } -/*************************************************************************/ - else if(poolingMode == 2) { // pnorm - PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, pgI, valO, sum, dstart, hstart, wstart, dend, hend, wend)) - for(int b = 0; b < bS; ++b) { - for(int c = 0; c < iC; ++c) { - for(int od = 0; od < oD; ++od) { - for(int oh = 0; oh < oH; ++oh) { - for(int ow = 0; ow < oW; ++ow) { - - pIn = in + b * iStride0 + c * iStride1; - pgI = gI + (pIn - in); +/*************************************************************************/ + else if(poolingMode == 2) { // pnorm + PRAGMA_OMP_PARALLEL_FOR_ARGS(private(pIn, pgI, valO, sum, dstart, hstart, wstart, dend, hend, wend)) + for(int b = 0; b < bS; ++b) { + for(int c = 0; c < iC; ++c) { + for(int od = 0; od < oD; ++od) { + for(int oh = 0; oh < oH; ++oh) { + for(int ow = 0; ow < oW; ++ow) { - dstart = od * sD - pD; - hstart = oh * sH - pH; - wstart = ow * sW - pW; - dend = dstart + kDEff; - hend = hstart + kHEff; - wend = wstart + kWEff; + pIn = in + b * iStride0 + c * iStride1; + pgI = gI + (pIn - in); - if(dstart < 0) - dstart += dD * ((-dstart + dD - 1) / dD); - if(hstart < 0) - hstart += dH * ((-hstart + dH - 1) / dH); - if(wstart < 0) - wstart += dW * ((-wstart + dW - 1) / dW); - if(dend > iD) - dend -= dD * ((dend-iD + dD - 1) / dD); - if(hend > iH) - hend -= dH * ((hend-iH + dH - 1) / dH); - if(wend > iW) - wend -= dW * ((wend-iW + dW - 1) / dW); + dstart = od * sD - pD; + hstart = oh * sH - pH; + wstart = ow * sW - pW; + dend = dstart + kDEff; + hend = hstart + kHEff; + wend = wstart + kWEff; - dstart *= iStride2; - dend *= iStride2; - hstart *= iStride3; - hend *= iStride3; - wstart *= iStride4; - wend *= iStride4; + if(dstart < 0) + dstart += dD * ((-dstart + dD - 1) / dD); + if(hstart < 0) + hstart += dH * ((-hstart + dH - 1) / dH); + if(wstart < 0) + wstart += dW * ((-wstart + dW - 1) / dW); + if(dend > iD) + dend -= dD * ((dend-iD + dD - 1) / dD); + if(hend > iH) + hend -= dH * ((hend-iH + dH - 1) / dH); + if(wend > iW) + wend -= dW * ((wend-iW + dW - 1) / dW); - sum = static_cast(0.); - valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; + dstart *= iStride2; + dend *= iStride2; + hstart *= iStride3; + hend *= iStride3; + wstart *= iStride4; + wend *= iStride4; - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); + sum = static_cast(0.); + valO = gO[b*oStride0 + c*oStride1+ od*oStride2 + oh*oStride3 + ow*oStride4]; - valO *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + sum += nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0); - for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) - for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) - for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) - pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - 1.f); + valO *= nd4j::math::nd4j_pow(sum, ((T)1.f - extraParam0) / extraParam0); + + for (Nd4jLong kd = dstart; kd < dend; kd += iStep2) + for (Nd4jLong kh = hstart; kh < hend; kh += iStep3) + for (Nd4jLong kw = wstart; kw < wend; kw += iStep4) + pgI[kd + kh + kw] += valO * nd4j::math::nd4j_pow(nd4j::math::nd4j_abs(pIn[kd + kh + kw]), extraParam0 - 1.f); + } + } } } } } + else { + nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); + throw ""; + } } + + + + + void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + } + void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + } + void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + } + void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + } + void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + } + void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + } + void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); + } + void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES); + } + void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { + BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); + } + void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + } + void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { + BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); + } + + + + void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + } + void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + } + void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + } + void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { + BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); + } + + + BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); + + BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); + } - else { - nd4j_printf("ConvolutionUtils::pooling3dBP: pooling mode argument can take three values only: 0, 1, 2, but got %i instead !\n", poolingMode); - throw ""; - } -} - - - - -void ConvolutionUtils::conv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -} -void ConvolutionUtils::conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -} -void ConvolutionUtils::depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -} -void ConvolutionUtils::depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -} -void ConvolutionUtils::sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -} -void ConvolutionUtils::vol2col(nd4j::LaunchContext & block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); -} -void ConvolutionUtils::col2vol(nd4j::LaunchContext & block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { - BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES); -} -void ConvolutionUtils::upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES); -} -void ConvolutionUtils::upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) { - BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES); -} -void ConvolutionUtils::upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); -} -void ConvolutionUtils::upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) { - BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES); -} - - - -void ConvolutionUtils::pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); -} -void ConvolutionUtils::pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); -} -void ConvolutionUtils::pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); -} -void ConvolutionUtils::pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) { - BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES); -} - - -BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); -BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES); - -BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); -BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES); - -} -} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp index 22e72545a..cc6da5ba4 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lstm.cpp @@ -227,15 +227,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast //NDArray* NDArrayFactory::create_( const char order, const std::vector &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) { std::vector shape = {bS, 4*numUnits}; - auto m = NDArrayFactory::create_('c', shape, xt->dataType(), nullptr); - MmulHelper::mmul(&concatOut, W, m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array - *m += (*b); //addiRowVector + auto m = NDArrayFactory::create('c', shape, xt->dataType()); + MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array + m += (*b); //addiRowVector //Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o]) - auto zi = (*m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits] - auto zz = (*m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits] - auto zf = (*m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits] - auto zo = (*m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits] + auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits] + auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits] + auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits] + auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits] if(peephole) { // add peephole connections: z + ct_1*Wc zi += (*cLast) * (*Wci); // add peephole connections to input gate diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index bfda53e56..06fe2eec2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -57,7 +57,7 @@ namespace helpers { ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor; - ConvolutionUtils::pooling2d(*block.launchContext(), *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); + ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1); if (nullptr != indices) { // for max_pool_with_argmax diff --git a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp index be17ac074..71c722bca 100644 --- a/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/BroadcastableOp.cpp @@ -53,9 +53,16 @@ namespace nd4j { dtype = nd4j::DataType::BOOL; if(shape::isEmpty(x) || shape::isEmpty(y)) { - //Edge case: broadcasting with empty array gives empty array output (behaviour to match TF for import cases) - auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype); - shapeList->push_back(empty); + // this is edge case, [3, 4] + [] = [] + if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) { + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype))); + return shapeList; + } + + + Nd4jLong *newshape = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace()); + shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype))); } else if (shape::isScalar(x) && shape::isScalar(y)) { if (shape::rank(x) >= shape::rank(y)) { shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype))); diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 9c14403c3..1b9f1fe4a 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -2873,7 +2873,7 @@ namespace simdOps { const static functions::ReduceType reduceType = functions::ReduceType::MAX; op_def static X startingValue(const X *input) { - return -nd4j::DataTypeUtils::max(); + return -nd4j::DataTypeUtils::infOrMax(); } op_def static X merge(X old, X opOutput, X *extraParams) { @@ -3051,7 +3051,7 @@ namespace simdOps { const static functions::ReduceType reduceType = functions::ReduceType::MIN; op_def static X startingValue(const X *input) { - return nd4j::DataTypeUtils::max(); + return nd4j::DataTypeUtils::infOrMax(); } op_def static X merge(X old, X opOutput, X *extraParams) { @@ -3831,7 +3831,7 @@ namespace simdOps { } static _CUDA_HD inline X startingValue(const X *input) { - return -nd4j::DataTypeUtils::max(); + return -nd4j::DataTypeUtils::infOrMax(); } static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(X *input) { @@ -3890,7 +3890,7 @@ namespace simdOps { } static _CUDA_HD inline X startingValue(const X *input) { - return -nd4j::DataTypeUtils::max(); + return -nd4j::DataTypeUtils::infOrMax(); } static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(X *input) { @@ -3958,7 +3958,7 @@ namespace simdOps { } static _CUDA_HD inline X startingValue(const X *input) { - return -nd4j::DataTypeUtils::max(); + return -nd4j::DataTypeUtils::infOrMax(); } static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(X *input) { @@ -3984,7 +3984,7 @@ namespace simdOps { } static _CUDA_HD inline X startingValue(const X *input) { - return nd4j::DataTypeUtils::max(); + return nd4j::DataTypeUtils::infOrMax(); } static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(X *input) { @@ -4040,7 +4040,7 @@ namespace simdOps { } static _CUDA_HD inline X startingValue(const X *input) { - return nd4j::DataTypeUtils::max(); + return nd4j::DataTypeUtils::infOrMax(); } static _CUDA_HD inline functions::indexreduce::IndexValue startingIndexValue(X *input) { diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index cf2c6d9ec..c6b834a33 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -580,6 +580,152 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_1) { ASSERT_TRUE(z.equalsTo(zExp)); } +TEST_F(BroadcastableOpsTests, broadcast_empty_2) { + + NDArray y('c', {1,4}, {1,2,3,4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4});; + + nd4j::ops::multiply op; + auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(e.isSameShape(x)); + ASSERT_TRUE(e.equalsTo(x)); +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_3) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 2}); + NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + nd4j::ops::maximum op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_4) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + nd4j::ops::maximum op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_5) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + nd4j::ops::realdiv op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_6) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + nd4j::ops::realdiv op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_7) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2, 0}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0});; + + nd4j::ops::realdiv op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + + +TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) { + + NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); + NDArray x(nd4j::DataType::DOUBLE, y.getContext(), false); + NDArray z(nd4j::DataType::BOOL, y.getContext(), false); + NDArray zExp(nd4j::DataType::BOOL, y.getContext(), false); + + nd4j::ops::greater op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); +} + +TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { + + NDArray y('c', {1,4}, {1,2,3,4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4});; + + + nd4j::ops::greater op; + auto result = op.execute({&x, &y}, {}, {}, {}); + + auto z = result->at(0); + + z->printShapeInfo("z"); + + ASSERT_EQ(Status::OK(), result->status()); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + TEST_F(BroadcastableOpsTests, broadcast_bool_1) { NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index a029d8f14..f7008c7d0 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -2021,7 +2021,8 @@ TEST_F(ConvolutionTests1, vol2col_test1) { // PointersManager manager(columnsExpected.getContext()); // manager.printDevContentOnHost(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf()); - nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); + graph::Context context(1); + nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); ASSERT_TRUE(columns.equalsTo(columnsExpected)); } @@ -2052,7 +2053,8 @@ TEST_F(ConvolutionTests1, vol2col_test2) { -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.}); - nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); + graph::Context context(1); + nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); ASSERT_TRUE(columns.equalsTo(columnsExpected)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 89e1d4110..1f94a18c3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -1302,8 +1302,8 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) { TEST_F(DeclarableOpsTests10, broadcast_to_test7) { auto input = NDArrayFactory::create(10.f); - auto shape = NDArrayFactory::create(0.f); - auto exp = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1); + auto exp = NDArrayFactory::create('c', {1}, {10.}); nd4j::ops::broadcast_to op; auto results = op.execute({&input, &shape}, {}, {}, {}); @@ -2261,8 +2261,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32); - NDArray min('c', {0}, {-63.65f}, nd4j::DataType::FLOAT32); - NDArray max('c', {0}, {0.1f}, nd4j::DataType::FLOAT32); + NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); + NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); nd4j::ops::fake_quant_with_min_max_vars op; auto results = op.execute({&x, &min, &max}, {}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 37ffe6368..01e8e82c2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); - NDArray dLdwExp('c', {0}, {-227.77286}); + NDArray dLdwExp('c', {}, {-227.77286}); NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); @@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {0}, {0.}); + NDArray dLdwExp('c', {}, {0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -583,7 +583,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); - NDArray dLdwExp('c', {0}, {4515.84}); + NDArray dLdwExp('c', {}, {4515.84}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -702,7 +702,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {0}, {0.}); + NDArray dLdwExp('c', {}, {0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); - NDArray dLdwExp('c', {0}, {288.}); + NDArray dLdwExp('c', {}, {288.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {0}, {0.}); + NDArray dLdwExp('c', {}, {0.}); predictions.linspace(0.04, 0.04); labels.linspace(1); @@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); - NDArray dLdwExp('c', {0}, {-91.52109}); + NDArray dLdwExp('c', {}, {-91.52109}); NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); @@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE); NDArray weights(nd4j::DataType::DOUBLE); - NDArray dLdwExp('c', {0}, {0.}); + NDArray dLdwExp('c', {}, {0.}); logits.linspace(-0.08, 0.04); labels.linspace(1); @@ -2001,10 +2001,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray logits('c', {4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {0}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); - NDArray dLdwExp('c', {0}, {1.38629}); + NDArray dLdwExp('c', {}, {1.38629}); logits = 2.; weights.assign(0.5); @@ -2032,10 +2032,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32); NDArray logits('c', {4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {0}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); - NDArray dLdwExp('c', {0}, {0.}); + NDArray dLdwExp('c', {}, {0.}); logits.linspace(-0.08, 0.04); weights = 0.5; @@ -2466,7 +2466,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { ///////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { - NDArray labels('c', {0}, {1}, nd4j::DataType::INT64); + NDArray labels('c', {}, {1}, nd4j::DataType::INT64); NDArray logits('c', {2}, {-0.2, 0.3}); NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 5ba1f8a81..936fef712 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {0}, nd4j::DataType::DOUBLE); + NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE); NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); - NDArray dLdwExp('c', {0}, {1.3}); + NDArray dLdwExp('c', {}, {1.3}); NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); predictions.linspace(-0.4, 0.2); @@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { TEST_F(DeclarableOpsTests12, hinge_loss_14) { NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE); - NDArray weights('c', {0}, {1.}); + NDArray weights('c', {}, {1.}); NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); - NDArray output('c', {0}, nd4j::DataType::DOUBLE); + NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE); logits.linspace(1.); weights.assign(1.); @@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { TEST_F(DeclarableOpsTests12, reverse_test15) { NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE); - NDArray axis('c', {0}, {0}, nd4j::DataType::INT32); + NDArray axis('c', {}, {0}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index d3b5091c3..57bdf0faf 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -123,6 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); + z->printIndexedBuffer("Reduced shape"); ASSERT_EQ(e, *z); delete result; @@ -212,4 +213,200 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { ASSERT_EQ(e, *result->at(0)); delete result; -} \ No newline at end of file +} + +TEST_F(DeclarableOpsTests14, test_empty_fill_1) { + auto x = NDArrayFactory::empty(); + auto y = NDArrayFactory::create(1); + + nd4j::ops::fill op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(y, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { + auto a = NDArrayFactory::create('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); + auto b = NDArrayFactory::create('c', {1, 3}); + auto c = NDArrayFactory::create('c', {1, 3}); + auto d = NDArrayFactory::create('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976}); + auto e = NDArrayFactory::create('c', {3}); + auto f = NDArrayFactory::create('c', {3}); + auto g = NDArrayFactory::create('c', {3}); + auto h = NDArrayFactory::create('c', {12}); + + auto z0 = NDArrayFactory::create('c', {1, 3}); + auto z1 = NDArrayFactory::create('c', {1, 3}); + auto z2 = NDArrayFactory::create('c', {1, 3}); + auto z3 = NDArrayFactory::create('c', {1, 3}); + auto z4 = NDArrayFactory::create('c', {1, 3}); + auto z5 = NDArrayFactory::create('c', {1, 3}); + auto z6 = NDArrayFactory::create('c', {1, 3}); + + nd4j::ops::lstmBlockCell op; + auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests14, test_empty_stack_1) { + auto x = NDArrayFactory::create('c', {0}); + auto e = NDArrayFactory::create('c', {1, 0}); + + nd4j::ops::stack op; + auto result = op.execute({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + nd4j::ops::reduce_min sumOp; + auto res2 = sumOp.execute({&e}, {1.}, {1}); + ASSERT_EQ(res2->status(), Status::OK()); + auto out = res2->at(0); + out->printShapeInfo("ReduceSum empty shape with keep dims"); + out->printIndexedBuffer("ReduceSum scalar"); + ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + delete res2; + delete result; +} + +TEST_F(DeclarableOpsTests14, test_empty_stack_2) { + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {0}); + + nd4j::ops::stack op; + auto result = op.execute({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests14, test_empty_stack_3) { + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {2, 0}); + + nd4j::ops::stack op; + auto result = op.execute({&x, &x}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests14, test_empty_stack_4) { + auto x = NDArrayFactory::create('c', {0}); + auto e = NDArrayFactory::create('c', {2, 0}); + + nd4j::ops::stack op; + auto result = op.execute({&x, &x}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + nd4j::ops::reduce_min sumOp; + auto res2 = sumOp.execute({&e}, {1.}, {1}); + ASSERT_EQ(res2->status(), Status::OK()); + auto out = res2->at(0); + + ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + delete res2; +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + nd4j::ops::reduce_max sumOp; + auto res2 = sumOp.execute({&e}, {1.}, {1}); + ASSERT_EQ(res2->status(), Status::OK()); + auto out = res2->at(0); + + ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); + delete res2; +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + nd4j::ops::reduce_sum sumOp; + auto res2 = sumOp.execute({&e}, {1.}, {1}); + ASSERT_EQ(res2->status(), Status::OK()); + auto out = res2->at(0); + ASSERT_EQ(out->e(0), 0.f); + delete res2; +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + nd4j::ops::reduce_mean sumOp; + auto res2 = sumOp.execute({&e}, {1.}, {1}); + ASSERT_EQ(res2->status(), Status::OK()); + auto out = res2->at(0); + out->printShapeInfo("ReduceMean empty shape with keep dims"); + out->printIndexedBuffer("ReduceMean scalar"); + ASSERT_EQ(out->e(0), 0.f); + delete res2; +} + +TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(0); + auto e = NDArrayFactory::create('c', {0}); + + nd4j::ops::argmax op; + //nd4j::ops::reduce_max op; + + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + z->printShapeInfo("Z"); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(1); + + nd4j::ops::argmax op; + try { + auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::exception &e) { + // + } +} + +TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { + auto x = NDArrayFactory::create('c', {32, 0}); + + nd4j::ops::tanh op; + auto result = op.execute({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_EQ(x, *z); + + delete result; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index a8ff9d886..d88f7821b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -905,9 +905,9 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { auto x = NDArrayFactory::create('c', {1}, {10}); - auto begin = NDArrayFactory::create('c', {1}, {0.}); - auto end = NDArrayFactory::create('c', {1}, {0.}); - auto stride = NDArrayFactory::create('c', {1}, {1}); + auto begin = NDArrayFactory::create('c', {1}, {(int)0}); + auto end = NDArrayFactory::create('c', {1}, {(int)0}); + auto stride = NDArrayFactory::create('c', {1}, {1}); //x.linspace(1); //auto exp = NDArrayFactory::create('c', {1,3,4,5}); //exp.linspace(1); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 5b38e58d5..6fa2b120a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -452,14 +452,14 @@ TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3_1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test1) { - + auto expected = NDArrayFactory::create('c', {3, 3}, {1, 0, 0, 0, 1, 0, 0, 0, 1}); nd4j::ops::eye op; auto results = op.execute({}, {}, {-99, 3}); auto output = results->at(0); // output->printIndexedBuffer(); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -469,7 +469,7 @@ TEST_F(DeclarableOpsTests5, eye_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test2) { - + auto expected = NDArrayFactory::create('c', {3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); nd4j::ops::eye op; @@ -485,14 +485,14 @@ TEST_F(DeclarableOpsTests5, eye_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test3) { - + auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); nd4j::ops::eye op; auto results = op.execute({}, {}, {-99, 3, 4, 2}); auto output = results->at(0); output->printIndexedBuffer("Output eye"); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -502,13 +502,13 @@ TEST_F(DeclarableOpsTests5, eye_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, eye_test4) { - + auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); nd4j::ops::eye op; auto results = op.execute({}, {}, {-99, 3, 4, 2, 2}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expected.isSameShape(output)); ASSERT_TRUE(expected.equalsTo(output)); @@ -633,7 +633,7 @@ TEST_F(DeclarableOpsTests5, gatherNd_test6) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {4}, {4,4,4,4}); @@ -642,7 +642,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {2, 1}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -652,7 +652,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {4}, {0,1,2,3}); @@ -661,7 +661,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {2, 1}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -671,7 +671,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {3}, {2,3,4}); @@ -680,7 +680,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {2, 0}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -690,7 +690,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); @@ -699,7 +699,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {0, 2}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -709,7 +709,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); @@ -718,7 +718,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {1, 2}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -728,7 +728,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); @@ -737,7 +737,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {0, 1}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -747,7 +747,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { - + auto input = NDArrayFactory::create('c', {1, 5}); input.linspace(1); std::vector data = {3}; @@ -757,7 +757,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {1, 0}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -767,7 +767,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { - + auto input = NDArrayFactory::create('c', {1, 5}); input.linspace(1); std::vector data = {1,0,1,0,1}; @@ -777,7 +777,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {0, 1}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -787,7 +787,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { - + auto input = NDArrayFactory::create('c', {5, 1}); input.linspace(1); std::vector data = {1,0,1,0,1}; @@ -797,7 +797,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {1, 0}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -807,7 +807,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { - + auto input = NDArrayFactory::create('c', {5, 1}); input.linspace(1); std::vector data = {3}; @@ -817,7 +817,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {0, 1}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -827,7 +827,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { - + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); input.linspace(1); std::vector data = {1, 0, 1, 0, 1}; @@ -837,7 +837,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {1, 2}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -847,7 +847,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { - + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); input.linspace(1); std::vector data = {3}; @@ -857,7 +857,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {2, 0}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -867,7 +867,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { - + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); input.linspace(1); std::vector data = {1}; @@ -877,7 +877,7 @@ TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { nd4j::ops::reverse_sequence op; auto results = op.execute({&input, &seqLengths}, {}, {3, 0}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1307,9 +1307,9 @@ TEST_F(DeclarableOpsTests5, Test_Moments_3) { 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); - + auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, - 8.5f, 11.f, 8.75f, 6.f, + 8.5f, 11.f, 8.75f, 6.f, 18.5f, 6.f, 13.75f, 11.f}); auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, @@ -1368,7 +1368,7 @@ TEST_F(DeclarableOpsTests5, Test_Moments_4) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test1) { - + auto input = NDArrayFactory::create('c', {3, 4, 5}); input.linspace(1); auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); @@ -1389,7 +1389,7 @@ TEST_F(DeclarableOpsTests5, trace_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test2) { - + auto input = NDArrayFactory::create('c', {4, 5}); input.linspace(1); auto exp = NDArrayFactory::create(40.); @@ -1397,7 +1397,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) { nd4j::ops::trace op; auto results = op.execute({&input}, {}, {}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1407,7 +1407,7 @@ TEST_F(DeclarableOpsTests5, trace_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test3) { - + auto input = NDArrayFactory::create('c', {1, 5}); input.linspace(1); auto exp = NDArrayFactory::create(1.); @@ -1415,7 +1415,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) { nd4j::ops::trace op; auto results = op.execute({&input}, {}, {}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1425,7 +1425,7 @@ TEST_F(DeclarableOpsTests5, trace_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test4) { - + auto input = NDArrayFactory::create('c', {5, 1}); input.linspace(1); auto exp = NDArrayFactory::create(1.); @@ -1433,7 +1433,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) { nd4j::ops::trace op; auto results = op.execute({&input}, {}, {}); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1443,7 +1443,7 @@ TEST_F(DeclarableOpsTests5, trace_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, trace_test5) { - + auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); input.linspace(1); auto exp = NDArrayFactory::create('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); @@ -1451,7 +1451,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) { nd4j::ops::trace op; auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto output = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1461,7 +1461,7 @@ TEST_F(DeclarableOpsTests5, trace_test5) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test1) { - + auto input = NDArrayFactory::create('c', {2, 2, 2}); input.linspace(1); @@ -1473,7 +1473,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { for(int i = 0; i < output->lengthOf(); ++i) if(output->e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!input.equalsTo(output)); @@ -1484,9 +1484,9 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test2) { - + auto input = NDArrayFactory::create('c', {1, 3, 2}); - input.linspace(1); + input.linspace(1); nd4j::ops::random_shuffle op; auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); @@ -1494,14 +1494,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); + ASSERT_TRUE(input.equalsTo(output)); delete results; } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test3) { - + auto input = NDArrayFactory::create('c', {3, 2, 1}); input.linspace(1); @@ -1513,7 +1513,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { for(int i = 0; i < output->lengthOf(); ++i) if(output->e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!input.equalsTo(output)); @@ -1535,7 +1535,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { for(int i = 0; i < output->lengthOf(); ++i) if(output->e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!input.equalsTo(output)); @@ -1546,7 +1546,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test5) { - + auto input = NDArrayFactory::create('c', {4,1}); input.linspace(1); @@ -1558,7 +1558,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) { for(int i = 0; i < output->lengthOf(); ++i) if(output->e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!input.equalsTo(output)); @@ -1569,7 +1569,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test5) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test6) { - + auto input = NDArrayFactory::create('c', {4,1,1}); input.linspace(1); @@ -1581,7 +1581,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { for(int i = 0; i < output->lengthOf(); ++i) if(output->e(i) == (float)0.) haveZeros = true; - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!input.equalsTo(output)); @@ -1592,7 +1592,7 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test6) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test7) { - + auto input = NDArrayFactory::create('c', {1,4}); input.linspace(1); auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4}); @@ -1611,11 +1611,11 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test7) { //////////////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { - + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); - + auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); auto exp = NDArrayFactory::create('c', {9, 4, 2}, {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, @@ -1637,17 +1637,17 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { ASSERT_TRUE(exp.isSameShape(output)); //output->printIndexedBuffer("Output"); //exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(output)); delete result; } TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { - + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, - 70, 80, 90, 10, 11, 12, - 13, 14, 15, 16, 17, 18, + 70, 80, 90, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); //1, 0, 1, 0, 1, 0 auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); @@ -1673,7 +1673,7 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { ASSERT_TRUE(exp.isSameShape(output)); // output->printIndexedBuffer("Output"); // exp.printIndexedBuffer("Expect"); - + ASSERT_TRUE(exp.equalsTo(output)); delete result; @@ -1721,19 +1721,19 @@ TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { } TEST_F(DeclarableOpsTests5, DynamicPartition_1) { - + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, 14, 24, 15, 25, 16, 26, 17, 27, 18, 28, 19, 29, 20, 30, 21, 31}); - + auto y = NDArrayFactory::create('c', {3, 4, 2}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f + 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f } ); /* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, - 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f + 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f } ); */ @@ -1762,7 +1762,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicPartition_2) { - + auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); @@ -1794,7 +1794,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_2) { TEST_F(DeclarableOpsTests5, DynamicPartition_3) { - + auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); auto y = NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); @@ -1817,7 +1817,7 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) { // output->printShapeInfo("Output shape> "); // exp[e].printShapeInfo("Expected shape> "); // output->printIndexedBuffer("Output data> "); - + ASSERT_TRUE(exp[e].isSameShape(output)); ASSERT_TRUE(exp[e].equalsTo(output)); } @@ -1833,13 +1833,13 @@ TEST_F(DeclarableOpsTests5, DynamicPartition_3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_1) { - + auto x1 = NDArrayFactory::create({1., 3., 5., 0.}); auto x2 = NDArrayFactory::create({2., 4.}); auto y2 = NDArrayFactory::create({-1., -1.}); auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - + auto exp = NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); nd4j::ops::dynamic_stitch op; @@ -1852,7 +1852,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) { // output->printShapeInfo("Output shape> "); // exp.printShapeInfo("Expected shape> "); // output->printIndexedBuffer("Output data> "); - // exp.printIndexedBuffer("Expected res>"); + // exp.printIndexedBuffer("Expected res>"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1862,13 +1862,13 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, DynamicStitch_2) { - + auto x1 = NDArrayFactory::create({1.f, 3.f}); auto x2 = NDArrayFactory::create({5.f, 0.f, 2.f, 4.f}); auto y1 = NDArrayFactory::create({-1.f, -1.f}); auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); - + auto exp = NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); nd4j::ops::dynamic_stitch op; @@ -1881,7 +1881,7 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) { // output->printShapeInfo("Output shape> "); // exp.printShapeInfo("Expected shape> "); // output->printIndexedBuffer("Output data> "); - // exp.printIndexedBuffer("Expected res>"); + // exp.printIndexedBuffer("Expected res>"); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -1890,11 +1890,11 @@ TEST_F(DeclarableOpsTests5, DynamicStitch_2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); x.linspace(1); auto scale = NDArrayFactory::create('c', {4}); - + scale = 0.5; auto offset = NDArrayFactory::create('c', {4}); offset = 2.; @@ -1908,7 +1908,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { auto y = results->at(0); auto batchMean = results->at(1); auto batchVar = results->at(2); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expY.isSameShape(y)); ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); @@ -1919,12 +1919,12 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); x.linspace(1); auto scale = NDArrayFactory::create('c', {4}); - + scale = 0.5; auto offset = NDArrayFactory::create('c', {4}); offset = 2.; @@ -1937,7 +1937,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { auto y = results->at(0); auto batchMean = results->at(1); auto batchVar = results->at(2); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expY.isSameShape(y)); ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); @@ -1948,12 +1948,12 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { - + auto x = NDArrayFactory::create('c', {2, 4, 2, 3}); x.linspace(1); - + auto scale = NDArrayFactory::create('c', {4}); - + scale = 0.5; auto offset = NDArrayFactory::create('c', {4}); offset = 2.; @@ -1966,7 +1966,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { auto y = results->at(0); auto batchMean = results->at(1); auto batchVar = results->at(2); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expY.isSameShape(y)); ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); @@ -1977,7 +1977,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); x.linspace(1); std::vector shape = {4}; @@ -1985,8 +1985,8 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { auto offset = NDArrayFactory::create('c', shape); auto mean = NDArrayFactory::create('c', shape); auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; + + scale = 0.5; offset = 2.; mean = 25.; variance = 5.; @@ -2001,7 +2001,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { auto y = results->at(0); auto batchMean = results->at(1); auto batchVar = results->at(2); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expY.isSameShape(y)); ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); @@ -2012,7 +2012,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { - + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); x.linspace(1); std::vector shape = {4}; @@ -2020,8 +2020,8 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { auto offset = NDArrayFactory::create('c', shape); auto mean = NDArrayFactory::create('c', shape); auto variance = NDArrayFactory::create('c', shape); - - scale = 0.5; + + scale = 0.5; offset = 2.; mean = 25.; variance = 5.; @@ -2036,7 +2036,7 @@ TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { auto y = results->at(0); auto batchMean = results->at(1); auto batchVar = results->at(2); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expY.isSameShape(y)); ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); @@ -2131,49 +2131,49 @@ TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { /////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_1) { - + auto x = NDArrayFactory::create('c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, - 70, 0, 90, 0, 11, 12, - 13, 14, 15, 16, 17, 18, + 70, 0, 90, 0, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 0, 21, 22, 23, 24}); nd4j::ops::zero_fraction op; auto res = op.execute({&x}, {}, {}); - + ASSERT_EQ(Status::OK(), res->status()); ASSERT_TRUE(res->at(0)->isScalar()); ASSERT_EQ(res->at(0)->e(0), 0.25); - + delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_2) { - + auto x = NDArrayFactory::create('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); nd4j::ops::zero_fraction op; auto res = op.execute({&x}, {}, {}); - + ASSERT_EQ(Status::OK(), res->status()); ASSERT_TRUE(res->at(0)->isScalar()); ASSERT_EQ(res->at(0)->e(0), 0.375); - + delete res; } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, ZeroFraction_3) { - + auto x = NDArrayFactory::create('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); nd4j::ops::zero_fraction op; auto res = op.execute({&x}, {}, {}); - + ASSERT_EQ(Status::OK(), res->status()); ASSERT_TRUE(res->at(0)->isScalar()); ASSERT_EQ(res->at(0)->e(0), 0.375); - + delete res; } @@ -2219,7 +2219,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_1) { // output->printShapeInfo("Output shape> "); // x.printShapeInfo("Expected shape> "); // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); + // x.printIndexedBuffer("Expected res>"); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); @@ -2242,7 +2242,7 @@ TEST_F(DeclarableOpsTests5, StopGradient_2) { // output->printShapeInfo("Output shape> "); // x.printShapeInfo("Expected shape> "); // output->printIndexedBuffer("Output data> "); - // x.printIndexedBuffer("Expected res>"); + // x.printIndexedBuffer("Expected res>"); ASSERT_TRUE(x.isSameShape(output)); ASSERT_TRUE(x.equalsTo(output)); @@ -2262,7 +2262,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test1) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2279,7 +2279,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test2) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2296,7 +2296,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test3) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; @@ -2314,7 +2314,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test5) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2331,7 +2331,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test6) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2348,7 +2348,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test7) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2365,7 +2365,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test8) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2382,7 +2382,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test9) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2399,7 +2399,7 @@ TEST_F(DeclarableOpsTests5, log_softmax_test10) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } @@ -2416,25 +2416,45 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) { ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(expOutput.isSameShape(z)); - ASSERT_TRUE(expOutput.equalsTo(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test12) { + + auto input = NDArrayFactory::create('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); + auto expOutput = NDArrayFactory::create('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472}); + + for (int i = 0; i < 10; ++i) + { + nd4j::ops::log_softmax op; + auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + auto z = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); + + delete results; + } +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); auto exp = NDArrayFactory::create('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689}); - + nd4j::ops::log_softmax_bp op; auto results = op.execute({&input, &epsilon}, {}, {}, {}, false, nd4j::DataType::DOUBLE); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); delete results; } @@ -2445,14 +2465,14 @@ TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) { auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); auto exp = NDArrayFactory::create('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); - + nd4j::ops::log_softmax_bp op; auto results = op.execute({&input, &epsilon}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); delete results; } @@ -2463,7 +2483,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) { auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto res = NDArrayFactory::create('c', {2, 2, 2}); - + input.applyTransform(transform::ELU, &res); ASSERT_TRUE(res.equalsTo(&exp)); @@ -2474,7 +2494,7 @@ TEST_F(DeclarableOpsTests5, L2_Loss_1) { auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); double exp(9.605); - + nd4j::ops::l2_loss op; auto results = op.execute({&input}, {}, {}); auto output = results->at(0); @@ -2522,14 +2542,14 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) { auto targets = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817}); - + nd4j::ops::log_poisson_loss op; auto results = op.execute({&input, &weights, &targets}, {}, {0}, {}, false, nd4j::DataType::DOUBLE); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); delete results; } @@ -2543,14 +2563,14 @@ TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) { auto targets = NDArrayFactory::create('c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882}); - + nd4j::ops::log_poisson_loss op; auto results = op.execute({&input, &weights, &targets}, {}, {0, 1}, {}, false, nd4j::DataType::DOUBLE); auto output = results->at(0); ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(output)); - ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(exp.equalsTo(output)); delete results; } @@ -2600,9 +2620,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_1) { auto outputDeviance = results->at(1); ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); delete results; } @@ -2651,9 +2671,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { auto outputDeviance = results->at(1); ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); delete results; } @@ -2702,9 +2722,9 @@ TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { auto outputDeviance = results->at(1); ASSERT_TRUE(expMeans.isSameShape(outputMeans)); - ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); - ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); delete results; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 2ad0d2219..e4068b12d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { auto e = NDArrayFactory::create('c', {1}, {0.}); auto s = NDArrayFactory::create('c', {1}, {1.0}); - //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto exp = NDArrayFactory::create(10); //matrix.linspace(1); @@ -119,7 +119,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { auto z = result->at(0); z->printShapeInfo("SS OS shape"); - ASSERT_TRUE(z->isEmpty()); + z->printIndexedBuffer("SS OS out"); + ASSERT_TRUE(z->equalsTo(exp)); //ASSERT_EQ(exp, *z); delete result; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 3152d8cae..3a6d85a2a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -52,7 +52,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { nd4j::ops::reduce_stdev_bp op; auto result = op.execute({&x, &gradO2}, {0,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto output = result->at(0); // output->printIndexedBuffer(); ASSERT_TRUE(exp.isSameShape(output)); @@ -60,7 +60,7 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { delete result; result = op.execute({&x, &gradO1}, {1,0}, {1}); - ASSERT_EQ(ND4J_STATUS_OK, result->status()); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -100,10 +100,10 @@ TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) { - + const int N = 50000; const double lambda = 2.; - const double mean = 1. / lambda; + const double mean = 1. / lambda; const double std = mean; auto x = NDArrayFactory::create('c', {N}); @@ -114,25 +114,25 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) { auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); - + functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.getShapeInfo(), extraParams); const double actualMean = x.meanNumber().e(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - + ASSERT_NEAR(mean, actualMean, 0.01); - ASSERT_NEAR(std, actualStd, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); nativeOps.destroyRandom((Nd4jPointer) rng); delete[] buffer; - + } ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { - + const int N = 50000; const double lambda = 2.; - const double mean = 1. / lambda; + const double mean = 1. / lambda; const double std = mean; double extraParams[] = {lambda}; @@ -146,14 +146,14 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); - + functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams); const double actualMean = x.meanNumber().e(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); - ASSERT_NEAR(std, actualStd, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); nativeOps.destroyRandom((Nd4jPointer) rng); delete[] buffer; @@ -162,10 +162,10 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) { - + const int N = 50000; const double lambda = 2.; - const double mean = 1. / lambda; + const double mean = 1. / lambda; const double std = mean; auto x = NDArrayFactory::create('c', {N}); @@ -176,25 +176,25 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) { auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); - + functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.getShapeInfo(), extraParams); const double actualMean = x.meanNumber().e(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); - + ASSERT_NEAR(mean, actualMean, 0.01); - ASSERT_NEAR(std, actualStd, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); nativeOps.destroyRandom((Nd4jPointer) rng); - delete[] buffer; + delete[] buffer; } */ ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { - + const int N = 50000; const double lambda = 2.; - const double mean = 1. / lambda; + const double mean = 1. / lambda; const double std = mean; double extraParams[] = {lambda}; @@ -210,7 +210,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer); if (rng == nullptr) throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); - + functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams); nativeOps.destroyRandom((Nd4jPointer) rng); @@ -218,7 +218,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { const double actualMean = x.meanNumber().e(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); ASSERT_NEAR(mean, actualMean, 0.01); - ASSERT_NEAR(std, actualStd, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); @@ -608,6 +608,24 @@ TEST_F(DeclarableOpsTests9, concat_test15) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test16) { + + auto x = NDArrayFactory::create('c', {0,2,3}); + auto y = NDArrayFactory::create('c', {0,2,3}); + auto exp = NDArrayFactory::create('c', {0,2,3}); + + nd4j::ops::concat op; + auto result = op.execute({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + delete result; +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test3) { @@ -739,11 +757,11 @@ TEST_F(DeclarableOpsTests9, matmul_test1) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -761,11 +779,11 @@ TEST_F(DeclarableOpsTests9, matmul_test2) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -782,11 +800,11 @@ TEST_F(DeclarableOpsTests9, matmul_test3) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -804,11 +822,11 @@ TEST_F(DeclarableOpsTests9, matmul_test4) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -826,11 +844,11 @@ TEST_F(DeclarableOpsTests9, matmul_test5) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -847,11 +865,11 @@ TEST_F(DeclarableOpsTests9, matmul_test6) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -870,11 +888,11 @@ TEST_F(DeclarableOpsTests9, matmul_test7) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -895,18 +913,18 @@ TEST_F(DeclarableOpsTests9, matmul_test8) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); delete results; } - + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, matmul_test9) { @@ -920,11 +938,11 @@ TEST_F(DeclarableOpsTests9, matmul_test9) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1133,11 +1151,11 @@ TEST_F(DeclarableOpsTests9, matmul_test10) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1174,7 +1192,7 @@ TEST_F(DeclarableOpsTests9, matmul_test12) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); @@ -1195,11 +1213,11 @@ TEST_F(DeclarableOpsTests9, matmul_test13) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {0, 0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1216,11 +1234,11 @@ TEST_F(DeclarableOpsTests9, matmul_test14) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1237,11 +1255,11 @@ TEST_F(DeclarableOpsTests9, matmul_test15) { x.linspace(1.); y.linspace(0.5, 0.5); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1261,11 +1279,11 @@ TEST_F(DeclarableOpsTests9, matmul_test16) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1282,11 +1300,11 @@ TEST_F(DeclarableOpsTests9, matmul_test17) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 0}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1303,11 +1321,11 @@ TEST_F(DeclarableOpsTests9, matmul_test18) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {0, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1324,11 +1342,11 @@ TEST_F(DeclarableOpsTests9, matmul_test19) { x.linspace(2.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1346,11 +1364,11 @@ TEST_F(DeclarableOpsTests9, matmul_test20) { x.linspace(2.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1,1,1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1368,11 +1386,11 @@ TEST_F(DeclarableOpsTests9, matmul_test21) { x.linspace(2.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1390,11 +1408,11 @@ TEST_F(DeclarableOpsTests9, matmul_test22) { x.linspace(2.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1412,11 +1430,11 @@ TEST_F(DeclarableOpsTests9, matmul_test23) { x.linspace(1.); y.linspace(0.1, 0.1); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1431,11 +1449,11 @@ TEST_F(DeclarableOpsTests9, matmul_test24) { auto x = NDArrayFactory::create('f', {1}, {2.}); auto y = NDArrayFactory::create('c', {1}, {3.}); auto exp = NDArrayFactory::create(6.); - + nd4j::ops::matmul op; auto results = op.execute({&x, &y}, {}, {1, 1}); auto z = results->at(0); - + ASSERT_EQ(Status::OK(), results->status()); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1534,34 +1552,34 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_test12) { - + const int bS = 5; const int nOut = 4; const int axis = 0; const double clip = 2.; - + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1] auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); auto expect = NDArrayFactory::create('c', {bS, nOut}); auto norm2 = x.reduceAlongDims(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] - + auto y = ( (x / norm2) * clip) * colVect ; auto temp = (x / norm2) * clip; for (int j = 0; j < nOut; ++j) { auto yCol = y({0,0, j,j+1}); const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); - if (norm2Col <= clip) + if (norm2Col <= clip) expect({0,0, j,j+1}).assign(yCol); - else + else expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) ); } - + nd4j::ops::clipbynorm op; auto result = op.execute({&y}, {clip}, {axis}, {}, false, nd4j::DataType::DOUBLE); - auto outFF = result->at(0); - + auto outFF = result->at(0); + ASSERT_TRUE(expect.isSameShape(outFF)); ASSERT_TRUE(expect.equalsTo(outFF)); @@ -1571,12 +1589,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_test12) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) { - + const int bS = 2; const int nOut = 3; const int axis = 0; const double clip = 0.7; - + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] auto gradO = NDArrayFactory::create('c', {bS, nOut}); @@ -1593,12 +1611,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) { - + const int bS = 2; const int nOut = 3; const int axis = 0; const double clip = 0.7; - + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] auto gradO = NDArrayFactory::create('c', {bS, nOut}); @@ -1616,12 +1634,12 @@ TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) { - + const int bS = 2; const int nOut = 3; const int axis = 1; const double clip = 1.; - + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] auto gradO = NDArrayFactory::create('c', {bS, nOut}); @@ -1734,7 +1752,7 @@ TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test1) { - + auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); auto axis = NDArrayFactory::create(1.); @@ -1745,7 +1763,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) { auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); auto gradO = NDArrayFactory::create('c', {3, 5}); - int exclusive, reverse; + int exclusive, reverse; //************************************// exclusive = 0; reverse = 0; @@ -1764,8 +1782,8 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) { /* exclusive = 1; reverse = 0; result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result->status()); - z = result->at(0); + ASSERT_EQ(Status::OK(), result->status()); + z = result->at(0); ASSERT_TRUE(expTF.equalsTo(z)); delete result; */ @@ -1773,8 +1791,8 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) { /* exclusive = 0; reverse = 1; result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result->status()); - z = result->at(0); + ASSERT_EQ(Status::OK(), result->status()); + z = result->at(0); ASSERT_TRUE(expFT.equalsTo(z)); delete result; */ @@ -1782,16 +1800,16 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) { /* exclusive = 1; reverse = 1; result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}); - ASSERT_EQ(Status::OK(), result->status()); - z = result->at(0); + ASSERT_EQ(Status::OK(), result->status()); + z = result->at(0); ASSERT_TRUE(expTT.equalsTo(z)); - delete result; -*/ + delete result; +*/ } //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, cumprod_test2) { - + auto inputC = NDArrayFactory::create('c', {2, 2}); auto axis = NDArrayFactory::create(1.); @@ -1802,7 +1820,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) { // auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); auto gradO = NDArrayFactory::create('c', {2, 2}); - int exclusive, reverse; + int exclusive, reverse; //************************************// exclusive = 0; reverse = 0; @@ -1820,7 +1838,7 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test1) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f,-0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1839,7 +1857,7 @@ TEST_F(DeclarableOpsTests9, prelu_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test2) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1857,7 +1875,7 @@ TEST_F(DeclarableOpsTests9, prelu_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test3) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {3,1}, {-0.6f, 2.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1875,7 +1893,7 @@ TEST_F(DeclarableOpsTests9, prelu_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test4) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1893,7 +1911,7 @@ TEST_F(DeclarableOpsTests9, prelu_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test5) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1911,7 +1929,7 @@ TEST_F(DeclarableOpsTests9, prelu_test5) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test6) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create('c', {1,1,1}, {-2.}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1930,7 +1948,7 @@ TEST_F(DeclarableOpsTests9, prelu_test6) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test7) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create(-2.f); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1948,7 +1966,7 @@ TEST_F(DeclarableOpsTests9, prelu_test7) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test8) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto alpha = NDArrayFactory::create(-2.f); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); @@ -1966,7 +1984,7 @@ TEST_F(DeclarableOpsTests9, prelu_test8) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test9) { - + auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); auto alpha = NDArrayFactory::create(-2.f); auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); @@ -1984,7 +2002,7 @@ TEST_F(DeclarableOpsTests9, prelu_test9) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test10) { - + auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); auto alpha = NDArrayFactory::create(-2.f); auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); @@ -2002,16 +2020,16 @@ TEST_F(DeclarableOpsTests9, prelu_test10) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test11) { - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); x.linspace(-50.); auto alpha = NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, - 32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, - 14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, - 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, - 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, - 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, + 32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, + 14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, + 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); nd4j::ops::prelu op; @@ -2027,15 +2045,15 @@ TEST_F(DeclarableOpsTests9, prelu_test11) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test12) { - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); x.linspace(-50.); auto alpha = NDArrayFactory::create('c', {3,5}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, + 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, + -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); nd4j::ops::prelu op; @@ -2051,15 +2069,15 @@ TEST_F(DeclarableOpsTests9, prelu_test12) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test13) { - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); x.linspace(-50.); auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, - 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, - -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, - 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, - 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, + 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, + -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); nd4j::ops::prelu op; @@ -2075,16 +2093,16 @@ TEST_F(DeclarableOpsTests9, prelu_test13) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_test14) { - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); x.linspace(-50.); auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f, -12.f, -15.6f, -19.f, -22.2f, -25.2f, -28.f, -30.6f, - -33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, - -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, - 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, - 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, - 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, + -33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, + -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); nd4j::ops::prelu op; @@ -2100,7 +2118,7 @@ TEST_F(DeclarableOpsTests9, prelu_test14) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { - + const float theta = 2.f; auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 3.f,4.f, 5.f, 6.f, 7.f,8.f, 9.f,10.f,11.f}); @@ -2116,10 +2134,10 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test1) { delete result; } - + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { - + const float theta = -2.f; auto x = NDArrayFactory::create('c', {2, 3, 4}, {0.f,-4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, 0.f, 5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); @@ -2128,7 +2146,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { auto result = op.execute({&x}, {theta}, {}, {}, false, nd4j::DataType::DOUBLE); ASSERT_EQ(ND4J_STATUS_OK, result->status()); - auto output = result->at(0); + auto output = result->at(0); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); @@ -2138,7 +2156,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test1) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); @@ -2156,7 +2174,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test2) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); @@ -2174,7 +2192,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test3) { - + auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); x.linspace(-30.); x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu @@ -2194,7 +2212,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, prelu_bp_test4) { - + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); x.linspace(-50.); x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele @@ -2214,7 +2232,7 @@ TEST_F(DeclarableOpsTests9, prelu_bp_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) { - + const double theta = 0.15; auto x = NDArrayFactory::create('c', {2, 3, 4}, {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6,-0.5,-0.4,-0.3,-0.2,-0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); @@ -2233,7 +2251,7 @@ TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test1) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}); auto y = NDArrayFactory::create('c', {4}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); @@ -2253,7 +2271,7 @@ TEST_F(DeclarableOpsTests9, multiply_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test2) { - + auto x = NDArrayFactory::create('c', {2, 3, 4}); auto y = NDArrayFactory::create(0.1); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); @@ -2273,7 +2291,7 @@ TEST_F(DeclarableOpsTests9, multiply_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test3) { - + auto x = NDArrayFactory::create('c', {2, 1, 4}); auto y = NDArrayFactory::create('c', {3,1}); auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); @@ -2293,11 +2311,11 @@ TEST_F(DeclarableOpsTests9, multiply_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test4) { - + auto x = NDArrayFactory::create('c', {1, 1}); auto y = NDArrayFactory::create(0.1f); auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); - x.linspace(1.f); + x.linspace(1.f); nd4j::ops::multiply op; auto result = op.execute({&x, &y}, {}, {}); @@ -2312,11 +2330,11 @@ TEST_F(DeclarableOpsTests9, multiply_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_test5) { - + auto x = NDArrayFactory::create(1.f); auto y = NDArrayFactory::create(0.1f); auto exp = NDArrayFactory::create(0.1f); - + nd4j::ops::multiply op; auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(ND4J_STATUS_OK, result->status()); @@ -2330,7 +2348,7 @@ TEST_F(DeclarableOpsTests9, multiply_test5) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test1) { - + auto x = NDArrayFactory::create('c', {1, 1}, {100.}); auto y = NDArrayFactory::create(0.1); auto dLdz = NDArrayFactory::create('c', {1, 1}); @@ -2353,7 +2371,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test2) { - + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); auto y = NDArrayFactory::create(0.1); auto dLdz = NDArrayFactory::create('c', {2, 2}); @@ -2371,7 +2389,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test3) { - + auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); auto x = NDArrayFactory::create(0.1); auto dLdz = NDArrayFactory::create('c', {2, 2}); @@ -2389,7 +2407,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test4) { - + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); auto y = NDArrayFactory::create('c', {2, 2}, {0.1,0.2,0.3,0.4}); auto dLdz = NDArrayFactory::create('c', {2, 2}); @@ -2407,7 +2425,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test5) { - + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); auto y = NDArrayFactory::create('c', {2}, {0.1,0.2}); auto dLdz = NDArrayFactory::create('c', {2, 2}); @@ -2425,7 +2443,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test5) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test6) { - + auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); auto x = NDArrayFactory::create('c', {2}, {0.1,0.2}); auto dLdz = NDArrayFactory::create('c', {2, 2}); @@ -2443,7 +2461,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test6) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test7) { - + auto y = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); auto x = NDArrayFactory::create('c', {2, 1}, {0.1,0.2}); auto dLdz = NDArrayFactory::create('c', {2, 3}); @@ -2461,7 +2479,7 @@ TEST_F(DeclarableOpsTests9, multiply_bp_test7) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, multiply_bp_test8) { - + auto y = NDArrayFactory::create('c', {2, 1, 4}); auto x = NDArrayFactory::create('c', {1, 3, 4}); auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index c3af0d5c2..65de2729f 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -59,7 +59,8 @@ TEST_F(EmptyTests, Test_Create_Empty_2) { } TEST_F(EmptyTests, Test_Concat_1) { - auto empty = NDArrayFactory::empty_(); +// auto empty = NDArrayFactory::empty_(); + auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; auto vector = NDArrayFactory::create_('c', {1}, {1.0f}); ASSERT_TRUE(empty->isEmpty()); @@ -82,9 +83,9 @@ TEST_F(EmptyTests, Test_Concat_1) { TEST_F(EmptyTests, Test_Concat_2) { - auto empty = NDArrayFactory::empty_(); - auto scalar1 = NDArrayFactory::create_(1.0f); - auto scalar2 = NDArrayFactory::create_(2.0f); + auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create_('c', {1}, {1.0f}); + auto scalar2 = NDArrayFactory::create_('c', {1}, {2.0f}); auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); ASSERT_TRUE(empty->isEmpty()); @@ -139,6 +140,23 @@ TEST_F(EmptyTests, Test_Reshape_2) { delete result; } +TEST_F(EmptyTests, Test_Reshape_3) { + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); + + nd4j::ops::reshape op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, *z); + + delete result; +} + TEST_F(EmptyTests, Test_dup_1) { auto empty = NDArrayFactory::empty(); auto dup = empty.dup(); @@ -147,4 +165,48 @@ TEST_F(EmptyTests, Test_dup_1) { ASSERT_EQ(empty, *dup); delete dup; +} + +TEST_F(EmptyTests, test_shaped_empty_1) { + auto empty = NDArrayFactory::create('c', {2, 0, 3}); + std::vector shape = {2, 0, 3}; + + ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(3, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_2) { + auto empty = NDArrayFactory::create('c', {0, 3}); + std::vector shape = {0, 3}; + + ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(2, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_3) { + auto empty = NDArrayFactory::create('c', {0}); + std::vector shape = {0}; + + ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(1, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_4) { + auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::FLOAT32); + shape::printShapeInfoLinear("shape", shape); + NDArray array(shape, true, nd4j::LaunchContext::defaultContext()); + std::vector shapeOf({0}); + + ASSERT_TRUE(array.isEmpty()); + ASSERT_EQ(1, array.rankOf()); + ASSERT_EQ(shapeOf, array.getShapeAsVector()); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index 9b460933c..a60bb34e6 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -667,4 +667,48 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { delete row; delete erow; +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(std::numeric_limits::infinity()); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(-std::numeric_limits::infinity()); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_transform_float_1) { + auto x = NDArrayFactory::create('c', {1, 0, 4}); + + NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index 32911c32f..b01348620 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -204,7 +204,7 @@ TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { NDArray x('f', {2}, {1.5, 3.5}, nd4j::DataType::FLOAT32); - NDArray y('c', {0}, {1.5}, nd4j::DataType::FLOAT32); + NDArray y('c', {}, {1.5}, nd4j::DataType::FLOAT32); const int* buffX = x.bufferAsT(); const int* buffY = y.bufferAsT(); @@ -217,8 +217,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) { NDArray x('c', {2,2}, {0, 1, 2, 3}, nd4j::DataType::UINT8); NDArray exp('c', {2,2}, {10, 10, 20, 20}, nd4j::DataType::UINT8); - NDArray scalar1('c', {0}, {10.5}, nd4j::DataType::FLOAT32); - NDArray scalar2('c', {0}, {20.8}, nd4j::DataType::DOUBLE); + NDArray scalar1('c', {}, {10.5}, nd4j::DataType::FLOAT32); + NDArray scalar2('c', {}, {20.8}, nd4j::DataType::DOUBLE); x(0,{0}).assign(scalar1); x(1,{0}).assign(scalar2); @@ -233,7 +233,7 @@ TEST_F(MultiDataTypeTests, ndarray_assign_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {0}, {3}, nd4j::DataType::INT64); + NDArray exp1('c', {}, {3}, nd4j::DataType::INT64); NDArray exp2('c', {1,1}, {1}, nd4j::DataType::INT64); NDArray exp3('c', {2}, {1,2}, nd4j::DataType::INT64); @@ -254,7 +254,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { NDArray x('c', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT32); - NDArray exp1('c', {0}, {1.5}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, {1.5}, nd4j::DataType::FLOAT32); NDArray exp2('c', {2}, {0.5,2.5}, nd4j::DataType::FLOAT32); auto* scalar1 = x.reduceAlongDimension(nd4j::reduce::Mean, {}/*whole range*/); @@ -272,7 +272,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {0}, {8.}, nd4j::DataType::HALF); + NDArray exp1('c', {}, {8.}, nd4j::DataType::HALF); NDArray exp2('c', {2}, {2.,6.}, nd4j::DataType::HALF); auto scalar1 = x.reduceAlongDims(nd4j::reduce::Sum, {}/*whole range*/); @@ -285,7 +285,7 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, nd4j::DataType::HALF); - NDArray exp1('c', {0}, {1}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); NDArray exp2('c', {2}, {1,0}, nd4j::DataType::BOOL); auto scalar1 = x.reduceAlongDims(nd4j::reduce::IsPositive, {}/*whole range*/); @@ -298,8 +298,8 @@ TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { NDArray x('f', {2, 2}, {0, 1, 2, 3}, nd4j::DataType::INT64); - NDArray exp1('c', {0}, {1.666666667}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {0}, {1.118033989}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, {1.666666667}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, {1.118033989}, nd4j::DataType::FLOAT32); auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); ASSERT_EQ(scalar1, exp1); @@ -1597,7 +1597,7 @@ TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {0}, {5}, nd4j::DataType::INT64); + NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); @@ -1619,10 +1619,10 @@ TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, nd4j::DataType::DOUBLE); - NDArray scalar('c', {0}, {5}, nd4j::DataType::INT64); + NDArray scalar('c', {}, {5}, nd4j::DataType::INT64); NDArray vec1('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray vec2('c', {3}, {1,1,1}, nd4j::DataType::INT64); - NDArray exp1('c', {0}, {5}, nd4j::DataType::INT64); + NDArray exp1('c', {}, {5}, nd4j::DataType::INT64); NDArray exp2('c', {2}, {2,2}, nd4j::DataType::INT64); NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::INT64); @@ -1643,8 +1643,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test1) { NDArray x2('c', {2,2}, {-1,-2,-3,-4}, nd4j::DataType::INT32); NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {2,2}, {1,2,3,4}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {0}, {-30}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {0}, {15}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); auto result = x1.applyReduce3(reduce3::Dot, &x2); ASSERT_EQ(*result, exp1); @@ -1667,8 +1667,8 @@ TEST_F(MultiDataTypeTests, applyReduce3_test2) { NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x8('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {0}, {-30}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {0}, {15}, nd4j::DataType::DOUBLE); + NDArray exp1('c', {}, {-30}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {}, {15}, nd4j::DataType::DOUBLE); NDArray exp3('c', {3}, {-18,-20,-18}, nd4j::DataType::FLOAT32); NDArray exp4('c', {2}, {-28,-28}, nd4j::DataType::FLOAT32); NDArray exp5('c', {3}, {7.5,10.5,13.5}, nd4j::DataType::DOUBLE); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp index bb4fde754..8f2856f91 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -1303,3 +1303,21 @@ TEST_F(NDArrayTest2, all_tads_1) { delete arrays; } + +TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); + + auto z = x + y; + + ASSERT_EQ(x, z); +} + +TEST_F(NDArrayTest2, test_trueBroadcast_empty_2) { + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); + + auto z = y + x; + + ASSERT_EQ(x, z); +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp index 9746d4c21..b5834d7cd 100644 --- a/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ParityOpsTests.cpp @@ -284,6 +284,55 @@ TEST_F(ParityOpsTests, TestUnstack9) { delete result; } +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack10) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {0,2}); + + nd4j::ops::unstack op; + + auto result = op.execute({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + ASSERT_TRUE(exp.isSameShape(result->at(0))); + ASSERT_TRUE(exp.isSameShape(result->at(1))); + ASSERT_TRUE(exp.isSameShape(result->at(2))); + + delete result; +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack11) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {3,0}); + + nd4j::ops::unstack op; + + auto result = op.execute({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + ASSERT_TRUE(exp.isSameShape(result->at(0))); + ASSERT_TRUE(exp.isSameShape(result->at(1))); + + delete result; +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack12) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + + nd4j::ops::unstack op; + + auto result = op.execute({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + ASSERT_TRUE(result->size() == 0); + + delete result; +} TEST_F(ParityOpsTests, ExpandDimsTest1) { auto input = NDArrayFactory::create('c', {5, 5}); @@ -765,7 +814,7 @@ TEST_F(ParityOpsTests, Test_Scatter_Add_6) { TEST_F(ParityOpsTests, Test_Scatter_Add_7) { auto matrix = NDArrayFactory::create('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); - NDArray idc('c', {0}, {5}, nd4j::DataType::INT64); + NDArray idc('c', {}, {5}, nd4j::DataType::INT64); auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); diff --git a/libnd4j/tests_cpu/layers_tests/TadTests.cpp b/libnd4j/tests_cpu/layers_tests/TadTests.cpp index 4aaff4dbe..aabef927f 100644 --- a/libnd4j/tests_cpu/layers_tests/TadTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/TadTests.cpp @@ -302,6 +302,7 @@ TEST_F(TadTests, calcOffsets_1) { ASSERT_TRUE(offsets[e] == expOffsetsF[e]); } + ///////////////////////////////////////////////////////////////// TEST_F(TadTests, outerArrayIndexes_1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index afddfcdce..f37c1658d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -193,6 +193,14 @@ public class DifferentialFunctionFactory { return new Range(sameDiff(), from, to, step, dataType).outputVariable(); } + public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { + return new Range(sameDiff(), from, to, step, dataType).outputVariable(); + } + + public SDVariable[] listdiff(SDVariable x, SDVariable y){ + return new ListDiff(sameDiff(), x, y).outputVariables(); + } + public SDVariable cast(SDVariable toCast, DataType toType){ return new Cast(sameDiff(), toCast, toType).outputVariable(); } @@ -860,6 +868,10 @@ public class DifferentialFunctionFactory { return new Permute(sameDiff(), iX, dimensions).outputVariable(); } + public SDVariable permute(SDVariable in, SDVariable dimensions) { + return new Permute(sameDiff(), in, dimensions).outputVariable(); + } + public SDVariable noop(SDVariable input) { return new NoOp(sameDiff(), input).outputVariable(); } @@ -1604,8 +1616,8 @@ public class DifferentialFunctionFactory { return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable(); } - public SDVariable logSumExp(SDVariable arg, int... dimension) { - return new LogSumExp(sameDiff(), arg, dimension).outputVariable(); + public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... dimension) { + return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java index 7cba5ef51..ebbe9dd5b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/BaseListener.java @@ -36,7 +36,7 @@ public abstract class BaseListener implements Listener { } @Override - public void opExecution(SameDiff sd, At at, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { //No op } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java index f92d532f2..627a282a3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/Listener.java @@ -57,7 +57,7 @@ public interface Listener { * @param op Operation that has just been executed * @param outputs The output arrays for the just-executed operation */ - void opExecution(SameDiff sd, At at, SameDiffOp op, INDArray[] outputs); + void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs); /** * Called just before each parameter is to be updated - i.e., just before each parameter is modified diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java index 597d5beb2..36ac88189 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java @@ -401,13 +401,13 @@ public class UIListener extends BaseListener { @Override - public void opExecution(SameDiff sd, At at, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { //Do training set evaluation, if required //Note we'll do it in opExecution not iterationDone because we can't be sure arrays will be stil be around in the future //i.e., we'll eventually add workspaces and clear activation arrays once they have been consumed - if(trainEvalMetrics != null && trainEvalMetrics.size() > 0){ + if(training && trainEvalMetrics != null && trainEvalMetrics.size() > 0){ long time = System.currentTimeMillis(); //First: check if this op is relevant at all to evaluation... diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 28aafe2d3..33a773415 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -1699,6 +1699,14 @@ public class SDVariable extends DifferentialFunction implements Serializable { return sameDiff.shape(this); } + /** + * Get the rank of this variable as a dynamic SDVariable + * @return Rank SDVariable + */ + public SDVariable rank(){ + return sameDiff.rank(this); + } + /** * Reshape the current variable to the specified (dynamic) shape. The output variable will have the same values as the * input, but with the specified shape.
@@ -1746,6 +1754,10 @@ public class SDVariable extends DifferentialFunction implements Serializable { return sameDiff.permute(this, dimensions); } + public SDVariable permute(SDVariable dimensions){ + return sameDiff.permute(null, this, dimensions); + } + /** * Associate the specified array with this variable * @param array Array to associate with this variable diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 55aecb354..3c7abdb18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -24,6 +24,7 @@ import com.rits.cloning.Cloner; import com.rits.cloning.IFastCloner; import lombok.*; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.commons.lang3.ArrayUtils; @@ -79,6 +80,7 @@ import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.DeviceLocalNDArray; +import org.nd4j.linalg.util.ND4JFileUtils; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.weightinit.WeightInitScheme; import org.nd4j.weightinit.impl.ConstantInitScheme; @@ -2573,13 +2575,38 @@ public class SameDiff extends SDBaseOps { //Remove updater state for now constant variables for (SDVariable v : variables) { GradientUpdater gu = updaterMap.remove(v.getVarName()); - Map m = gu.getState(); + Map m = gu == null ? null : gu.getState(); if(m != null){ for(INDArray arr : m.values()){ if(arr.closeable()) arr.close(); } } + + //Also check dataset feature/label mapping - remove any placeholders here... + if(trainingConfig.getDataSetFeatureMapping() != null && trainingConfig.getDataSetFeatureMapping().contains(v.getVarName())){ + List newFM = new ArrayList<>(trainingConfig.getDataSetFeatureMapping()); //New list in case of immutable list + newFM.remove(v.getVarName()); + trainingConfig.setDataSetFeatureMapping(newFM); + } + + if(trainingConfig.getDataSetLabelMapping() != null && trainingConfig.getDataSetLabelMapping().contains(v.getVarName())){ + List newLM = new ArrayList<>(trainingConfig.getDataSetLabelMapping()); + newLM.remove(v.getVarName()); + trainingConfig.setDataSetLabelMapping(newLM); + } + + if(trainingConfig.getDataSetFeatureMaskMapping() != null && trainingConfig.getDataSetFeatureMaskMapping().contains(v.getVarName())){ + List newFMM = new ArrayList<>(trainingConfig.getDataSetFeatureMaskMapping()); + newFMM.remove(v.getVarName()); + trainingConfig.setDataSetFeatureMaskMapping(newFMM); + } + + if(trainingConfig.getDataSetLabelMaskMapping() != null && trainingConfig.getDataSetLabelMaskMapping().contains(v.getVarName())){ + List newLMM = new ArrayList<>(trainingConfig.getDataSetLabelMaskMapping()); + newLMM.remove(v.getVarName()); + trainingConfig.setDataSetLabelMaskMapping(newLMM); + } } } } @@ -3156,143 +3183,85 @@ public class SameDiff extends SDBaseOps { outputDataTypes = function.calculateOutputDataTypes(inputDataTypes); } - val outputShape = function.calculateOutputShape(); - if (outputShape == null || outputShape.isEmpty()) { - if (function instanceof CustomOp) { - CustomOp customOp = (CustomOp) function; - //can't guess number of outputs, variable - int num_outputs = function.getNumOutputs(); //Use this in preference - if set. Descriptor might specify 2, but it can sometimes be 2+ + //Determine number of output variables + if (function instanceof CustomOp) { + CustomOp customOp = (CustomOp) function; + int num_outputs = function.getNumOutputs(); //Use this in preference - if set. Descriptor might specify 2, but it can sometimes be 2+ + if (num_outputs <= 0) { + val descriptor = customOp.getDescriptor(); + if (descriptor != null) { + num_outputs = descriptor.getNumOutputs(); + } if (num_outputs <= 0) { - val descriptor = customOp.getDescriptor(); - if (descriptor != null) { - num_outputs = descriptor.getNumOutputs(); - } - if (num_outputs <= 0) { - throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " - + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override" + - " getNumOutputs() to specify number of outputs if required"); - } + throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + + function.getOwnName() + " - " + function.getClass().getSimpleName() + ". Ops can override" + + " getNumOutputs() to specify number of outputs if required"); } - char ordering = 'c'; - SDVariable[] args = function.args(); - if (args != null && args.length > 0 && args[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye - ordering = function.args()[0].getArr().ordering(); - } - SDVariable[] ret = new SDVariable[num_outputs]; - - //Infer the output types: we can always determine datatype but not always shapes - Preconditions.checkState(isImport || num_outputs == 0 || (outputDataTypes != null && outputDataTypes.size() == num_outputs), - "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)", - (outputDataTypes == null ? null : outputDataTypes.size()), num_outputs, outputDataTypes, function.getClass().getSimpleName()); - - //dynamic shapes - //When importing from TF: convention is "unstack", "unstack:1", "unstack:2", ... - for (int i = 0; i < ret.length; i++) { - SDVariable var = (i == 0 ? getVariable(baseName) : getVariable(baseName + ":" + i)); - if (var == null) { - //Generate new variable name if one with the specified name doesn't exist - //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - - org.nd4j.linalg.api.buffer.DataType dataType = isImport ? null : outputDataTypes.get(i); - var = var(generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[])null); - } - var.setOutputIndex(i); - var.setCreator(function); - ret[i] = var; - } - - //Update the internal state: outgoing variables for function - if (getOutputsForFunction(function) == null) - addOutgoingFor(ret, function); - - return ret; } + SDVariable[] ret = new SDVariable[num_outputs]; - //this is for unresolved shapes, we know xyz is always 1 output - else if (function instanceof BaseOp && outputShape.isEmpty()) { - SDVariable[] ret = new SDVariable[1]; - SDVariable checkGet = getVariable(baseName); - char ordering = 'c'; - SDVariable[] args = function.args(); - if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye - ordering = function.args()[0].getArr().ordering(); - } - if (checkGet == null) { + //Infer the output types: we can always determine datatype but not always shapes + Preconditions.checkState(isImport || (outputDataTypes != null && outputDataTypes.size() == num_outputs), + "Incorrect number of output datatypes: got %s but expected datatypes for %s outputs - %s (op: %s)", + (outputDataTypes == null ? null : outputDataTypes.size()), num_outputs, outputDataTypes, function.getClass().getSimpleName()); + + //dynamic shapes + //When importing from TF: convention is "unstack", "unstack:1", "unstack:2", ... + for (int i = 0; i < ret.length; i++) { + SDVariable var = (i == 0 ? getVariable(baseName) : getVariable(baseName + ":" + i)); + if (var == null) { + //Generate new variable name if one with the specified name doesn't exist //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); - checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); + + org.nd4j.linalg.api.buffer.DataType dataType = isImport ? null : outputDataTypes.get(i); + var = var(generateNewVarName(baseName, i), VariableType.ARRAY, null, dataType, (long[])null); } - - if (checkGet == null) { - //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); - checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); - } - - checkGet.setOutputIndex(0); - checkGet.setCreator(function); - ret[0] = checkGet; - - - //Update the internal state: outgoing variables for function - if (getOutputsForFunction(function) == null) - addOutgoingFor(ret, function); - - return ret; + var.setOutputIndex(i); + var.setCreator(function); + ret[i] = var; } + + //Update the internal state: outgoing variables for function + if (getOutputsForFunction(function) == null) + addOutgoingFor(ret, function); + + return ret; } - //Check that output shapes and output dtypes actually match (they should) - if(!isImport) { - for (int i = 0; i < outputShape.size(); i++) { - org.nd4j.linalg.api.buffer.DataType shapeDataType = outputShape.get(i).dataType(); - org.nd4j.linalg.api.buffer.DataType calcType = outputDataTypes.get(i); - Preconditions.checkState(calcType == shapeDataType, "Calculated output data types do not match for shape calculation vs. datatype calculation:" + - " %s vs %s for op %s output %s", shapeDataType, calcType, function.getClass().getName(), i); - } - } - - char ordering = 'c'; - if (function.args() != null && function.args().length > 0 && function.args()[0].getArr() != null) { - ordering = function.args()[0].getArr().ordering(); - } - - SDVariable[] ret = new SDVariable[outputShape.size()]; - - // ownName/baseName will be used to get variables names - val ownName = function.getOwnName(); - val rootName = baseName; - for (int i = 0; i < ret.length; i++) { - LongShapeDescriptor shape = outputShape.get(i); - // it should be: rootName:index. i.e.: split:1, split:2, split:3, split:4 etc - baseName = rootName + (i > 0 ? ":" + i : ""); + //this is for unresolved shapes, we know xyz is always 1 output + else if (function instanceof BaseOp) { + SDVariable[] ret = new SDVariable[1]; SDVariable checkGet = getVariable(baseName); + char ordering = 'c'; + SDVariable[] args = function.args(); + if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye + ordering = function.args()[0].getArr().ordering(); + } if (checkGet == null) { - // obviously - there's no such var, just add it //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme - - - checkGet = var(baseName, VariableType.ARRAY, null, shape.dataType(), shape.getShape()); - } else if (shape != null && !shapeAlreadyExistsForVarName(checkGet.getVarName())) { - // var exists, let's update its shape - putShapeForVarName(checkGet.getVarName(), shape); - } else if (shape != null && shapeAlreadyExistsForVarName(checkGet.getVarName())) { - // no-op. - // TODO: maybe we should check shapes equality here? - // it's either var that already exist, or something bad happening + org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); + checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); } if (checkGet == null) { - org.nd4j.linalg.api.buffer.DataType dataType = org.nd4j.linalg.api.buffer.DataType.FLOAT; //TODO FIX THIS - checkGet = var(baseName + (i > 0 ? ":" + i : ""), new ZeroInitScheme(ordering), dataType, shape.getShape()); + //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme + org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); + checkGet = var(baseName, VariableType.ARRAY, null, dataType, (long[])null); } - checkGet.setOutputIndex(i); + checkGet.setOutputIndex(0); checkGet.setCreator(function); - ret[i] = checkGet; - } + ret[0] = checkGet; - return ret; + + //Update the internal state: outgoing variables for function + if (getOutputsForFunction(function) == null) + addOutgoingFor(ret, function); + + return ret; + } else { + throw new RuntimeException("Unknown op type: " + function.getClass()); + } } /** @@ -4682,7 +4651,7 @@ public class SameDiff extends SDBaseOps { * however may increase the file size significantly. * If the network is to be used for inference only, set this to false to save space */ - public void save(File file, boolean saveUpdaterState) { + public void save(@NonNull File file, boolean saveUpdaterState) { try { asFlatFile(file, saveUpdaterState); } catch (IOException e) { @@ -4690,8 +4659,37 @@ public class SameDiff extends SDBaseOps { } } + /** + * As per {@link #save(File, boolean)} but the serialized SameDiff instance is written to the output stream instead. + * Note that this temporarily saves to disk (using {@link ND4JFileUtils#createTempFile(String, String)} then copies all + * file bytes to the stream + * + * @param outputStream Stream to write the serialized SameDiff instance to + * @param saveUpdater If true: save the updater state (arrays etc for Adam, Nesterov, RmsProp etc). If false: don't save + * the updater state. If you want to continue training after loading your model, this should be true, + * however may increase the file size significantly. + * If the network is to be used for inference only, set this to false to save space. + */ + public void save(@NonNull OutputStream outputStream, boolean saveUpdater) { + File tempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp"); + try { + save(tempFile, saveUpdater); + if (!(outputStream instanceof BufferedOutputStream)) { + outputStream = new BufferedOutputStream(outputStream); + } + try (OutputStream os = outputStream; InputStream is = new BufferedInputStream(new FileInputStream(tempFile))) { + IOUtils.copy(is, os); + } catch (IOException e) { + throw new RuntimeException("Error writing to output stream (or reading from temp file)", e); + } + } finally { + tempFile.delete(); + } + } + /** * Load the SameDiff instance previously saved with {@link #save(File, boolean)} + * * @param file The file to load the network from * @param loadUpdaterState If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc). * For inference only, this should be false, as the updater state will take more memory, but @@ -4700,14 +4698,39 @@ public class SameDiff extends SDBaseOps { * The updater state can only be loaded if it was saved with the network. * @return The loaded SameDiff network */ - public static SameDiff load(File file, boolean loadUpdaterState) { - try{ + public static SameDiff load(@NonNull File file, boolean loadUpdaterState) { + try { return fromFlatFile(file, loadUpdaterState); - } catch (IOException e){ + } catch (IOException e) { throw new RuntimeException("Error loading SameDiff instance from file", e); } } + /** + * As per {@link #load(File, boolean)} but the SameDiff instance + * + * @param is Input stream to load the saved network from + * @param loadUpdaterState If true - load the updater state (history etc for updaters such as Adam, Nesterov momentum, RMSProp etc). + * For inference only, this should be false, as the updater state will take more memory, but + * is not required for training. + * If the network is to be trained further, this should be true. + * The updater state can only be loaded if it was saved with the network. + * @return The loaded SameDiff network + */ + public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState) { + File tempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp"); + try { + try (OutputStream os = new BufferedOutputStream(new FileOutputStream(tempFile))) { + IOUtils.copy(is, os); + } + return fromFlatFile(tempFile, loadUpdaterState); + } catch (IOException e) { + throw new RuntimeException("Error loading SameDiff instance from file", e); + } finally { + tempFile.delete(); + } + } + /** * This method converts SameDiff instance to FlatBuffers and saves it to file which can be restored later
* This includes the updater state, if applicable diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index f241ffe22..ed053d37c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -109,7 +109,7 @@ public class InferenceSession extends AbstractSession 0){ SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName()); for(Listener l : listeners){ - l.opExecution(sameDiff, at, sdOp, out); + l.opExecution(sameDiff, at, training, sdOp, out); } } return out; @@ -663,18 +663,11 @@ public class InferenceSession extends AbstractSession 0, "Invalid shape for op %s: shape has invalid values <= 0: shape=%s", customOp.opName(), shape); - } - } - if(currOutput == null || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop){ INDArray out; try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { @@ -689,6 +682,7 @@ public class InferenceSession extends AbstractSession outputShape = ((BaseOp)op).calculateOutputShape(); - Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); - INDArray z = op.z(); - if(z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop){ - if(log.isTraceEnabled()){ - log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", - op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString()); + if(emptyReduce){ + INDArray z = op.z(); + if (z == null || !op.x().equalShapes(z) || isLoop) { + //Note: edge case: [x,y].sum(empty) = [x,y] for TF import compatibility. + op.setZ(op.x().ulike()); } + } else { + List outputShape = ((BaseOp) op).calculateOutputShape(); + Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); + INDArray z = op.z(); + if (z == null || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) { + if (log.isTraceEnabled()) { + log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}", + op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString()); + } - LongShapeDescriptor lsd = outputShape.get(0); - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - //TODO Proper workspace support will be added to SameDiff later - z = Nd4j.create(lsd, false); + LongShapeDescriptor lsd = outputShape.get(0); + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + //TODO Proper workspace support will be added to SameDiff later + z = Nd4j.create(lsd, false); + } + op.setZ(z); } - op.setZ(z); } df.resolvePropertiesFromSameDiffBeforeExecution(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 27c86334a..0fb5dc360 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -1593,6 +1593,14 @@ public abstract class SDBaseOps { return updateVariableNameAndReference(result, name); } + /** + * As per {@link #permute(String, SDVariable, int...)} but with SDVariable permute dimension + */ + public SDVariable permute(String name, SDVariable x, SDVariable dimensions){ + SDVariable result = f().permute(x, dimensions); + return updateVariableNameAndReference(result, name); + } + /** * Product array reduction operation, optionally along specified dimensions * @@ -1668,6 +1676,14 @@ public abstract class SDBaseOps { return updateVariableNameAndReference(ret, name); } + /** + * As per {@link #range(String, double, double, double, DataType)} but with SDVariable arguments + */ + public SDVariable range(String name, SDVariable from, SDVariable to, SDVariable step, DataType dataType) { + SDVariable ret = f().range(from, to, step, dataType); + return updateVariableNameAndReference(ret, name); + } + /** * Returns the rank (number of dimensions, i.e., length(shape)) of the specified SDVariable as a 0D scalar variable * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 76a99a11e..5543db3c9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -1530,6 +1530,21 @@ public class SDMath extends SDOps { return lastIndex(null, in, condition, keepDims, dimensions); } + /** + * List diff operation computes the difference between two 1d arrays, and also returns the indices - i.e., the positions + * where the output appears in the input X.
+ * For inputs X and Y, listDiff returns everything in X but not in Y.
+ * For example, if {@code X=[1,10,3,7,6]} and {@code Y=[10, 6]), then: + * output 0 (difference) = {@code [1,3,7]}
+ * output 1 (indices) = {@code [0, 2, 3]}
+ * @param x Input 1 - input values + * @param y Input 2 - values to remove + * @return 2 outputs - difference, and indices + */ + public SDVariable[] listDiff(SDVariable x, SDVariable y){ + return f().listdiff(x, y); + } + /** * Element-wise logarithm function (base e - natural logarithm): out = log(x) * @@ -1648,8 +1663,12 @@ public class SDMath extends SDOps { * @return Output variable */ public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { + return logSumExp(name, input, false, dimensions); + } + + public SDVariable logSumExp(String name, SDVariable input, boolean keepDims, int... dimensions) { validateNumerical("logSumExp reduction", input); - SDVariable ret = f().logSumExp(input, dimensions); + SDVariable ret = f().logSumExp(input, keepDims, dimensions); return updateVariableNameAndReference(ret, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index 50aa9fe90..d2b783a04 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -158,8 +158,8 @@ public class GradCheckUtil { int totalCount = 0; double maxError = 0.0; for(SDVariable s : sd.variables()){ - if (fnOutputs.contains(s.getVarName())) { - //This is not an input to the graph + if (fnOutputs.contains(s.getVarName()) || !s.dataType().isFPType()) { + //This is not an input to the graph, or is not a floating point input (so can't be gradient checked) continue; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java index 13b98bb7c..e07b5c9b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java @@ -24,6 +24,8 @@ import org.nd4j.evaluation.classification.*; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo; +import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.AtomicBoolean; import org.nd4j.linalg.primitives.AtomicDouble; @@ -33,6 +35,7 @@ import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean; import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble; import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean; import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble; +import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.shade.jackson.annotation.JsonAutoDetect; import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.databind.DeserializationFeature; @@ -44,6 +47,7 @@ import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import java.io.IOException; import java.io.Serializable; +import java.util.Arrays; import java.util.List; /** @@ -215,11 +219,11 @@ public abstract class BaseEvaluation implements IEvalu if(mask == null){ return reshapeSameShapeTo2d(axis, labels, predictions, mask); } else { - if(labels.rank() == 3){ - if(mask.rank() == 2){ + if(labels.rank() == 3) { + if (mask.rank() == 2) { //Per time step masking - Pair p = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask); - if(p == null){ + Pair p = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask); + if (p == null) { return null; } return new Triple<>(p.getFirst(), p.getSecond(), null); @@ -237,8 +241,26 @@ public abstract class BaseEvaluation implements IEvalu if(labels.equalShapes(mask)){ //Per output masking case return reshapeSameShapeTo2d(axis, labels, predictions, mask); + } else if(mask.rank() == 1){ + //Treat 1D mask as per-example masking + Preconditions.checkState(mask.length() == labels.size(0), "For rank 4 labels with shape %ndShape and 1d" + + " mask of shape %ndShape, the mask array length must equal labels dimension 0 size", labels, mask); + long[] reshape = ArrayUtil.nTimes(labels.rank(), 1L); + reshape[0] = mask.size(0); + INDArray mReshape = mask.reshape(reshape); + INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape()); + BroadcastTo b = new BroadcastTo(mReshape, labels.shape(), bMask); + Nd4j.exec(b); + return reshapeSameShapeTo2d(axis, labels, predictions, bMask); + } else if(mask.rank() == labels.rank() && Shape.areShapesBroadcastable(mask.shape(), labels.shape())){ + //Same rank, but different shape -> broadcast + INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape()); + BroadcastTo b = new BroadcastTo(mask, labels.shape(), bMask); + Nd4j.exec(b); + return reshapeSameShapeTo2d(axis, labels, predictions, bMask); } - throw new UnsupportedOperationException("Evaluation case not yet implemented: rank 4/5 labels with non-per-output mask arrays"); + throw new UnsupportedOperationException("Evaluation case not supported: labels shape " + Arrays.toString(labels.shape()) + + " with mask shape " + Arrays.toString(mask.shape())); } } } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index 30ed1b3f7..f74d3e054 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -614,7 +614,10 @@ public class Evaluation extends BaseEvaluation { } private String stats(boolean suppressWarnings, boolean includeConfusion, boolean logConfusionSizeWarning){ - String actual, predicted; + if(numRowCounter == 0){ + return "Evaluation: No data available (no evaluation has been performed)"; + } + StringBuilder builder = new StringBuilder().append("\n"); StringBuilder warnings = new StringBuilder(); ConfusionMatrix confusion = confusion(); @@ -820,6 +823,7 @@ public class Evaluation extends BaseEvaluation { * @return the precision for the label */ public double precision(Integer classLabel, double edgeCase) { + Preconditions.checkState(numRowCounter > 0, "Cannot get precision: no evaluation has been performed"); double tpCount = truePositives.getCount(classLabel); double fpCount = falsePositives.getCount(classLabel); return EvaluationUtils.precision((long) tpCount, (long) fpCount, edgeCase); @@ -851,9 +855,7 @@ public class Evaluation extends BaseEvaluation { * @return Average precision */ public double precision(EvaluationAveraging averaging) { - if(getNumRowCounter() == 0){ - return 0.0; //No data - } + Preconditions.checkState(numRowCounter > 0, "Cannot get precision: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroPrecision = 0.0; @@ -966,6 +968,7 @@ public class Evaluation extends BaseEvaluation { * @return Recall rate as a double */ public double recall(int classLabel, double edgeCase) { + Preconditions.checkState(numRowCounter > 0, "Cannot get recall: no evaluation has been performed"); double tpCount = truePositives.getCount(classLabel); double fnCount = falseNegatives.getCount(classLabel); @@ -998,9 +1001,7 @@ public class Evaluation extends BaseEvaluation { * @return Average recall */ public double recall(EvaluationAveraging averaging) { - if(getNumRowCounter() == 0.0){ - return 0.0; //No data - } + Preconditions.checkState(numRowCounter > 0, "Cannot get recall: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroRecall = 0.0; @@ -1046,6 +1047,7 @@ public class Evaluation extends BaseEvaluation { * @return fpr as a double */ public double falsePositiveRate(int classLabel, double edgeCase) { + Preconditions.checkState(numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed"); double fpCount = falsePositives.getCount(classLabel); double tnCount = trueNegatives.getCount(classLabel); @@ -1078,6 +1080,7 @@ public class Evaluation extends BaseEvaluation { * @return Average false positive rate */ public double falsePositiveRate(EvaluationAveraging averaging) { + Preconditions.checkState(numRowCounter > 0, "Cannot get false positive rate: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroFPR = 0.0; @@ -1117,6 +1120,7 @@ public class Evaluation extends BaseEvaluation { * @return fnr as a double */ public double falseNegativeRate(Integer classLabel, double edgeCase) { + Preconditions.checkState(numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed"); double fnCount = falseNegatives.getCount(classLabel); double tpCount = truePositives.getCount(classLabel); @@ -1149,6 +1153,7 @@ public class Evaluation extends BaseEvaluation { * @return Average false negative rate */ public double falseNegativeRate(EvaluationAveraging averaging) { + Preconditions.checkState(numRowCounter > 0, "Cannot get false negative rate: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroFNR = 0.0; @@ -1223,6 +1228,7 @@ public class Evaluation extends BaseEvaluation { * @return F_beta */ public double fBeta(double beta, int classLabel, double defaultValue) { + Preconditions.checkState(numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed"); double precision = precision(classLabel, -1); double recall = recall(classLabel, -1); if (precision == -1 || recall == -1) { @@ -1271,9 +1277,7 @@ public class Evaluation extends BaseEvaluation { * @param averaging Averaging method to use */ public double fBeta(double beta, EvaluationAveraging averaging) { - if(getNumRowCounter() == 0.0){ - return Double.NaN; //No data - } + Preconditions.checkState(numRowCounter > 0, "Cannot get fBeta score: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (nClasses == 2) { @@ -1315,6 +1319,7 @@ public class Evaluation extends BaseEvaluation { * @return The G-measure for the specified output */ public double gMeasure(int output) { + Preconditions.checkState(numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed"); double precision = precision(output); double recall = recall(output); return EvaluationUtils.gMeasure(precision, recall); @@ -1327,6 +1332,7 @@ public class Evaluation extends BaseEvaluation { * @return Average G measure */ public double gMeasure(EvaluationAveraging averaging) { + Preconditions.checkState(numRowCounter > 0, "Cannot get gMeasure: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroGMeasure = 0.0; @@ -1359,9 +1365,7 @@ public class Evaluation extends BaseEvaluation { * @return the accuracy of the guesses so far */ public double accuracy() { - if (getNumRowCounter() == 0) { - return 0.0; //No records - } + Preconditions.checkState(numRowCounter > 0, "Cannot get accuracy: no evaluation has been performed"); //Accuracy: sum the counts on the diagonal of the confusion matrix, divide by total int nClasses = confusion().getClasses().size(); int countCorrect = 0; @@ -1391,6 +1395,7 @@ public class Evaluation extends BaseEvaluation { * @param classIdx Class index to calculate Matthews correlation coefficient for */ public double matthewsCorrelation(int classIdx) { + Preconditions.checkState(numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed"); return EvaluationUtils.matthewsCorrelation((long) truePositives.getCount(classIdx), (long) falsePositives.getCount(classIdx), (long) falseNegatives.getCount(classIdx), (long) trueNegatives.getCount(classIdx)); @@ -1405,6 +1410,7 @@ public class Evaluation extends BaseEvaluation { * @return Average */ public double matthewsCorrelation(EvaluationAveraging averaging) { + Preconditions.checkState(numRowCounter > 0, "Cannot get Matthews correlation: no evaluation has been performed"); int nClasses = confusion().getClasses().size(); if (averaging == EvaluationAveraging.Macro) { double macroMatthewsCorrelation = 0.0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index d4f9a5c14..bb4ad2396 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; +import org.nd4j.linalg.primitives.Triple; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; @@ -62,6 +63,9 @@ public class EvaluationBinary extends BaseEvaluation { public static final int DEFAULT_PRECISION = 4; public static final double DEFAULT_EDGE_VALUE = 0.0; + @EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality + protected int axis = 1; + //Because we want evaluation to work for large numbers of examples - and with low precision (FP16), we won't //use INDArrays to store the counts private int[] countTruePositive; //P=1, Act=1 @@ -119,6 +123,29 @@ public class EvaluationBinary extends BaseEvaluation { } } + /** + * Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
+ * For DL4J, this can be left as the default setting (axis = 1).
+ * Axis should be set as follows:
+ * For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
+ * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
+ * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
+ * For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
+ * For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
+ * + * @param axis Axis to use for evaluation + */ + public void setAxis(int axis){ + this.axis = axis; + } + + /** + * Get the axis - see {@link #setAxis(int)} for details + */ + public int getAxis(){ + return axis; + } + @Override public void eval(INDArray labels, INDArray networkPredictions) { eval(labels, networkPredictions, (INDArray) null); @@ -126,61 +153,47 @@ public class EvaluationBinary extends BaseEvaluation { @Override public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List recordMetaData) { - throw new UnsupportedOperationException("Not yet implemented"); - } - - @Override - public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) { - if (labelsMask == null || labelsMask.rank() == 2) { - super.evalTimeSeries(labels, predictions, labelsMask); - return; - } else if (labelsMask.rank() != 3) { - throw new IllegalArgumentException("Labels must: must be rank 2 or 3. Got: " + labelsMask.rank()); + if(recordMetaData != null){ + throw new UnsupportedOperationException("Evaluation with record metadata not yet implemented for EvaluationBinary"); } - - //Per output time series masking - INDArray l2d = EvaluationUtils.reshapeTimeSeriesTo2d(labels); - INDArray p2d = EvaluationUtils.reshapeTimeSeriesTo2d(predictions); - INDArray m2d = EvaluationUtils.reshapeTimeSeriesTo2d(labelsMask); - - eval(l2d, p2d, m2d); + eval(labels, networkPredictions, maskArray); } @Override - public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) { + public void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr) { //Check for NaNs in predictions - without this, evaulation could silently be intepreted as class 0 prediction due to argmax - long count = Nd4j.getExecutioner().execAndReturn(new MatchCondition(networkPredictions, Conditions.isNan())).getFinalResult().longValue(); + long count = Nd4j.getExecutioner().execAndReturn(new MatchCondition(predictionsArr, Conditions.isNan())).getFinalResult().longValue(); org.nd4j.base.Preconditions.checkState(count == 0, "Cannot perform evaluation with NaNs present in predictions:" + " %s NaNs present in predictions INDArray", count); - if (countTruePositive != null && countTruePositive.length != labels.size(1)) { + if (countTruePositive != null && countTruePositive.length != labelsArr.size(axis)) { throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with " - + "size " + countTruePositive.length + ", got labels array with size " + labels.size(1)); + + "size " + countTruePositive.length + ", got labels array with size " + labelsArr.size(axis) + " for axis " + axis); } - if (labels.rank() == 3) { - evalTimeSeries(labels, networkPredictions, maskArray); - return; - } + Triple p = BaseEvaluation.reshapeAndExtractNotMasked(labelsArr, predictionsArr, maskArr, axis); + INDArray labels = p.getFirst(); + INDArray predictions = p.getSecond(); + INDArray maskArray = p.getThird(); - if(labels.dataType() != networkPredictions.dataType()) - labels = labels.castTo(networkPredictions.dataType()); + if(labels.dataType() != predictions.dataType()) + labels = labels.castTo(predictions.dataType()); - if(decisionThreshold != null && decisionThreshold.dataType() != networkPredictions.dataType()) - decisionThreshold = decisionThreshold.castTo(networkPredictions.dataType()); + if(decisionThreshold != null && decisionThreshold.dataType() != predictions.dataType()) + decisionThreshold = decisionThreshold.castTo(predictions.dataType()); //First: binarize the network prediction probabilities, threshold 0.5 unless otherwise specified //This gives us 3 binary arrays: labels, predictions, masks INDArray classPredictions; if (decisionThreshold != null) { - classPredictions = Nd4j.createUninitialized(DataType.BOOL, networkPredictions.shape()); + classPredictions = Nd4j.createUninitialized(DataType.BOOL, predictions.shape()); Nd4j.getExecutioner() - .exec(new BroadcastGreaterThan(networkPredictions, decisionThreshold, classPredictions, 1)); + .exec(new BroadcastGreaterThan(predictions, decisionThreshold, classPredictions, 1)); } else { - classPredictions = networkPredictions.gt(0.5); + classPredictions = predictions.gt(0.5); } - classPredictions = classPredictions.castTo(networkPredictions.dataType()); + classPredictions = classPredictions.castTo(predictions.dataType()); INDArray notLabels = labels.rsub(1.0); //If labels are 0 or 1, then rsub(1) swaps INDArray notClassPredictions = classPredictions.rsub(1.0); @@ -218,7 +231,7 @@ public class EvaluationBinary extends BaseEvaluation { addInPlace(countFalseNegative, fnCount); if (rocBinary != null) { - rocBinary.eval(labels, networkPredictions, maskArray); + rocBinary.eval(labels, predictions, maskArray); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java index 97cb0f91e..8d5ff279f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java @@ -369,7 +369,7 @@ public class EvaluationCalibration extends BaseEvaluation * @param classIdx Index of the class to get the reliability diagram for */ public ReliabilityDiagram getReliabilityDiagram(int classIdx) { - + Preconditions.checkState(rDiagBinPosCount != null, "Unable to get reliability diagram: no evaluation has been performed (no data)"); INDArray totalCountBins = rDiagBinTotalCount.getColumn(classIdx); INDArray countPositiveBins = rDiagBinPosCount.getColumn(classIdx); @@ -441,6 +441,7 @@ public class EvaluationCalibration extends BaseEvaluation * @return Residual plot (histogram) - all predictions/classes */ public Histogram getResidualPlot(int labelClassIdx) { + Preconditions.checkState(rDiagBinPosCount != null, "Unable to get residual plot: no evaluation has been performed (no data)"); String title = "Residual Plot - Predictions for Label Class " + labelClassIdx; int[] counts = residualPlotByLabelClass.getColumn(labelClassIdx).dup().data().asInt(); return new Histogram(title, 0.0, 1.0, counts); @@ -465,6 +466,7 @@ public class EvaluationCalibration extends BaseEvaluation * @return Probability histogram */ public Histogram getProbabilityHistogram(int labelClassIdx) { + Preconditions.checkState(rDiagBinPosCount != null, "Unable to get probability histogram: no evaluation has been performed (no data)"); String title = "Network Probabilities Histogram - P(class " + labelClassIdx + ") - Data Labelled Class " + labelClassIdx + " Only"; int[] counts = probHistogramByLabelClass.getColumn(labelClassIdx).dup().data().asInt(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java index 93b2cd63b..b124f14d7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROC.java @@ -18,6 +18,7 @@ package org.nd4j.evaluation.classification; import lombok.*; import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.base.Preconditions; import org.nd4j.evaluation.BaseEvaluation; import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; @@ -76,6 +77,12 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @JsonSerialize(using = ROCSerializer.class) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY) public class ROC extends BaseEvaluation { + /** + * AUROC: Area under ROC curve
+ * AUPRC: Area under Precision-Recall Curve + */ + public enum Metric {AUROC, AUPRC} + private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048; private final Map counts = new LinkedHashMap<>(); private int thresholdSteps; @@ -189,9 +196,7 @@ public class ROC extends BaseEvaluation { return auc; } - if (exampleCount == 0) { - return Double.NaN; - } + Preconditions.checkState(exampleCount > 0, "Unable to calculate AUC: no evaluation has been performed (no examples)"); this.auc = getRocCurve().calculateAUC(); return auc; @@ -207,6 +212,8 @@ public class ROC extends BaseEvaluation { return rocCurve; } + Preconditions.checkState(exampleCount > 0, "Unable to get ROC curve: no evaluation has been performed (no examples)"); + if (isExact) { //Sort ascending. As we decrease threshold, more are predicted positive. //if(prob <= threshold> predict 0, otherwise predict 1 @@ -354,9 +361,7 @@ public class ROC extends BaseEvaluation { return auprc; } - if (exampleCount == 0) { - return Double.NaN; - } + Preconditions.checkState(exampleCount > 0, "Unable to calculate AUPRC: no evaluation has been performed (no examples)"); auprc = getPrecisionRecallCurve().calculateAUPRC(); return auprc; @@ -376,6 +381,8 @@ public class ROC extends BaseEvaluation { return prCurve; } + Preconditions.checkState(exampleCount > 0, "Unable to get PR curve: no evaluation has been performed (no examples)"); + double[] thresholdOut; double[] precisionOut; double[] recallOut; @@ -779,6 +786,10 @@ public class ROC extends BaseEvaluation { @Override public String stats() { + if(this.exampleCount == 0){ + return "ROC: No data available (no data has been performed)"; + } + StringBuilder sb = new StringBuilder(); sb.append("AUC (Area under ROC Curve): ").append(calculateAUC()).append("\n"); sb.append("AUPRC (Area under Precision/Recall Curve): ").append(calculateAUCPR()); @@ -789,4 +800,15 @@ public class ROC extends BaseEvaluation { } return sb.toString(); } + + public double scoreForMetric(Metric metric){ + switch (metric){ + case AUROC: + return calculateAUC(); + case AUPRC: + return calculateAUCPR(); + default: + throw new IllegalStateException("Unknown metric: " + metric); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java index 89fa367dd..e58b019a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCBinary.java @@ -49,6 +49,12 @@ import java.util.List; public class ROCBinary extends BaseEvaluation { public static final int DEFAULT_STATS_PRECISION = 4; + /** + * AUROC: Area under ROC curve
+ * AUPRC: Area under Precision-Recall Curve + */ + public enum Metric {AUROC, AUPRC} + @JsonSerialize(using = ROCArraySerializer.class) private ROC[] underlying; @@ -392,4 +398,16 @@ public class ROCBinary extends BaseEvaluation { public static ROCBinary fromJson(String json){ return fromJson(json, ROCBinary.class); } + + public double scoreForMetric(Metric metric, int idx){ + assertIndex(idx); + switch (metric){ + case AUROC: + return calculateAUC(idx); + case AUPRC: + return calculateAUCPR(idx); + default: + throw new IllegalStateException("Unknown metric: " + metric); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java index b24fc3ccd..8cf2c3aca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/ROCMultiClass.java @@ -45,6 +45,12 @@ import java.util.List; public class ROCMultiClass extends BaseEvaluation { public static final int DEFAULT_STATS_PRECISION = 4; + /** + * AUROC: Area under ROC curve
+ * AUPRC: Area under Precision-Recall Curve + */ + public enum Metric {AUROC, AUPRC} + private int thresholdSteps; private boolean rocRemoveRedundantPts; @JsonSerialize(using = ROCArraySerializer.class) @@ -344,4 +350,16 @@ public class ROCMultiClass extends BaseEvaluation { public static ROCMultiClass fromJson(String json){ return fromJson(json, ROCMultiClass.class); } + + public double scoreForMetric(Metric metric, int idx){ + assertIndex(idx); + switch (metric){ + case AUROC: + return calculateAUC(idx); + case AUPRC: + return calculateAUCPR(idx); + default: + throw new IllegalStateException("Unknown metric: " + metric); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java index cd7f074fe..67c8e94d5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Triple; import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize; @@ -69,6 +70,8 @@ public class RegressionEvaluation extends BaseEvaluation { public static final int DEFAULT_PRECISION = 5; + @EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality + protected int axis = 1; private boolean initialized; private List columnNames; private long precision; @@ -151,6 +154,29 @@ public class RegressionEvaluation extends BaseEvaluation { } } + /** + * Set the axis for evaluation - this is the dimension along which the probability (and label classes) are present.
+ * For DL4J, this can be left as the default setting (axis = 1).
+ * Axis should be set as follows:
+ * For 2D (OutputLayer), shape [minibatch, numClasses] - axis = 1
+ * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength] - axis = 1
+ * For 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses] - axis = 2
+ * For 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width] - axis = 1
+ * For 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels] - axis = 3
+ * + * @param axis Axis to use for evaluation + */ + public void setAxis(int axis){ + this.axis = axis; + } + + /** + * Get the axis - see {@link #setAxis(int)} for details + */ + public int getAxis(){ + return axis; + } + @Override public void reset() { initialized = false; @@ -194,20 +220,11 @@ public class RegressionEvaluation extends BaseEvaluation { } @Override - public void eval(INDArray labels, INDArray predictions, INDArray maskArray) { - if (labels.rank() == 3) { - //Time series data - evalTimeSeries(labels, predictions, maskArray); - return; - } - - if (maskArray != null && !Arrays.equals(maskArray.shape(), labels.shape())) { - //Time series (per time step) masks are handled in evalTimeSeries by extracting the relevant steps - // and flattening to 2d - throw new RuntimeException("Per output masking detected, but mask array and labels have different shapes: " - + Arrays.toString(maskArray.shape()) + " vs. labels shape " - + Arrays.toString(labels.shape())); - } + public void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr) { + Triple p = BaseEvaluation.reshapeAndExtractNotMasked(labelsArr, predictionsArr, maskArr, axis); + INDArray labels = p.getFirst(); + INDArray predictions = p.getSecond(); + INDArray maskArray = p.getThird(); if(labels.dataType() != predictions.dataType()) labels = labels.castTo(predictions.dataType()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 094dab317..c734fffc6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -16,6 +16,7 @@ package org.nd4j.imports.graphmapper.tf; +import com.github.os72.protobuf351.Descriptors; import com.github.os72.protobuf351.Message; import com.google.common.primitives.Floats; import com.google.common.primitives.Ints; @@ -38,6 +39,8 @@ import org.nd4j.imports.graphmapper.BaseGraphMapper; import org.nd4j.imports.graphmapper.ImportState; import org.nd4j.imports.graphmapper.OpImportFilter; import org.nd4j.imports.graphmapper.OpImportOverride; +import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMapper; +import org.nd4j.imports.graphmapper.tf.tensors.TFTensorMappers; import org.nd4j.linalg.api.buffer.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState; @@ -961,7 +964,7 @@ public class TFGraphMapper extends BaseGraphMapper dimensions = new ArrayList<>(); - for (int e = 0; e < dims; e++) { - // TODO: eventually we want long shapes :( - int dim = (int) tfTensor.getTensorShape().getDim(e).getSize(); - dimensions.add(dim); + + TFTensorMapper m = TFTensorMappers.newMapper(tfTensor); + if(m == null){ + throw new RuntimeException("Not implemented datatype: " + tfTensor.getDtype()); } - - - - arrayShape = ArrayUtil.toLongArray(Ints.toArray(dimensions)); - - if (tfTensor.getDtype() == DataType.DT_INT8 || tfTensor.getDtype() == DataType.DT_UINT8) { - // valueOf - if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if (tfTensor.getIntValCount() < 1) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), 0); - - //should be scalar otherwise - int val = tfTensor.getIntVal(0); - - if (arrayShape == null || arrayShape.length == 0) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), val); - - return Nd4j.valueArrayOf(arrayShape, val, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else if (tfTensor.getIntValCount() > 0) { - val jArray = new int[tfTensor.getIntValCount()]; - for (int e = 0; e < tfTensor.getIntValCount(); e++) { - jArray[e] = tfTensor.getIntVal(e); - } - - // TF arrays are always C - return Nd4j.create(Nd4j.createTypedBuffer(jArray, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else { - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val fb = bb.order(ByteOrder.nativeOrder()).asReadOnlyBuffer(); - val fa = new byte[fb.capacity()]; - for (int e = 0; e < fb.capacity(); e++) - fa[e] = fb.get(e); - - if (fa.length == 0) - return Nd4j.empty(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - //throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (fa.length == 1) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), fa[0]); - - if (arrayShape.length == 1) - return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - - val array = Nd4j.create(Nd4j.createTypedBuffer(fa, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_INT16 || tfTensor.getDtype() == DataType.DT_UINT16) { - // valueOf - if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if (tfTensor.getIntValCount() < 1) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), 0); - - //should be scalar otherwise - int val = tfTensor.getIntVal(0); - - if (arrayShape == null || arrayShape.length == 0) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), val); - - return Nd4j.valueArrayOf(arrayShape, val, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else if (tfTensor.getIntValCount() > 0) { - val jArray = new int[tfTensor.getIntValCount()]; - for (int e = 0; e < tfTensor.getIntValCount(); e++) { - jArray[e] = tfTensor.getIntVal(e); - } - - // TF arrays are always C - return Nd4j.create(Nd4j.createTypedBuffer(jArray, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else { - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val fb = bb.order(ByteOrder.nativeOrder()).asShortBuffer(); - val fa = new short[fb.capacity()]; - for (int e = 0; e < fb.capacity(); e++) - fa[e] = fb.get(e); - - if (fa.length == 0) - return Nd4j.empty(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - //throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (fa.length == 1) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), fa[0]); - - if (arrayShape.length == 1) - return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - - val array = Nd4j.create(Nd4j.createTypedBuffer(fa, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_INT32 || tfTensor.getDtype() == DataType.DT_UINT32) { - // valueOf - if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if(tfTensor.getIntValCount() < 1) - return Nd4j.scalar( ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), 0); - - //should be scalar otherwise - int val = tfTensor.getIntVal(0); - - if (arrayShape == null || arrayShape.length == 0) - return Nd4j.scalar( ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), val); - - return Nd4j.valueArrayOf(arrayShape, val, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else if (tfTensor.getIntValCount() > 0) { - val jArray = new int[tfTensor.getIntValCount()]; - for (int e = 0; e < tfTensor.getIntValCount(); e++) { - jArray[e] = tfTensor.getIntVal(e); - } - - // TF arrays are always C - return Nd4j.create(Nd4j.createTypedBuffer(jArray, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - } else { - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val fb = bb.order(ByteOrder.nativeOrder()).asIntBuffer(); - val fa = new int[fb.capacity()]; - for (int e = 0; e < fb.capacity(); e++) - fa[e] = fb.get(e); - - if (fa.length == 0) - return Nd4j.empty(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - //throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (fa.length == 1) - return Nd4j.scalar(ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()), fa[0]); - - if (arrayShape.length == 1) - return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - - val array = Nd4j.create(Nd4j.createTypedBuffer(fa, ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', ArrayOptionsHelper.convertToDataType(tfTensor.getDtype())); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_FLOAT) { - if (tfTensor.getFloatValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if(tfTensor.getFloatValCount() < 1) - return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, 0.0f); - - - float val = tfTensor.getFloatVal(0); - - if (arrayShape == null || arrayShape.length == 0) - arrayShape = new long[]{}; - - INDArray array = Nd4j.valueArrayOf(arrayShape, val, org.nd4j.linalg.api.buffer.DataType.FLOAT); - return array; - } else if (tfTensor.getFloatValCount() > 0) { - float[] jArray = new float[tfTensor.getFloatValCount()]; - for (int e = 0; e < tfTensor.getFloatValCount(); e++) { - jArray[e] = tfTensor.getFloatVal(e); - } - - INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.FLOAT), arrayShape, Nd4j.getStrides(arrayShape), 0, 'c'); - return array; - } else if (tfTensor.getTensorContent().size() > 0){ - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val fb = bb.order(ByteOrder.nativeOrder()).asFloatBuffer(); - val fa = new float[fb.capacity()]; - for (int e = 0; e < fb.capacity(); e++) - fa[e] = fb.get(e); - - if (fa.length == 0) - throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (fa.length == 1) - return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.FLOAT, fa[0]); - - if (arrayShape.length == 1) - return Nd4j.create(fa, new long[]{fa.length}, new long[]{1}, 'c', org.nd4j.linalg.api.buffer.DataType.FLOAT); - - val array = Nd4j.create(fa, arrayShape, Nd4j.getStrides(arrayShape, 'c'), 'c', org.nd4j.linalg.api.buffer.DataType.FLOAT); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_DOUBLE) { - if (tfTensor.getDoubleValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if(tfTensor.getDoubleValCount() < 1) - return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.DOUBLE, 0.0); - - double val = tfTensor.getDoubleVal(0); - INDArray array = Nd4j.trueScalar(val); - if (arrayShape.length > 0) - return array.reshape('c', arrayShape); - else - return array; - } else if (tfTensor.getDoubleValCount() > 0) { - val jArray = new double[tfTensor.getDoubleValCount()]; - for (int e = 0; e < tfTensor.getDoubleValCount(); e++) { - jArray[e] = tfTensor.getDoubleVal(e); - } - - // TF arrays are always C - val array = Nd4j.create(jArray, arrayShape, Nd4j.getStrides(arrayShape, 'c'), 'c', org.nd4j.linalg.api.buffer.DataType.DOUBLE); - return array; - } else if (tfTensor.getTensorContent().size() > 0) { - // binary representation - //DataBuffer buffer = Nd4j.createBuffer(tfTensor.getTensorContent().asReadOnlyByteBuffer(), DataType.FLOAT, (int) length); - //INDArray array = Nd4j.createArrayFromShapeBuffer(buffer, Nd4j.getShapeInfoProvider().createShapeInformation(arrayShape, 'c')); - - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val fb = bb.order(ByteOrder.nativeOrder()).asDoubleBuffer(); - val da = new double[fb.capacity()]; - for (int e = 0; e < fb.capacity(); e++) - da[e] = fb.get(e); - - if (da.length == 0) - throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (da.length == 1) - return Nd4j.trueScalar(da[0]); - - if (arrayShape.length == 1) - return Nd4j.trueVector(da); - - val array = Nd4j.create(da, arrayShape, 0, 'c'); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_BFLOAT16) { - if (tfTensor.getHalfValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if(tfTensor.getHalfValCount() < 1) - return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.BFLOAT16, 0.0); - - int val = tfTensor.getHalfVal(0); //FP16 byte returned as int32 bytes (not cast/conversion) with 2 bytes padding :/ - INDArray array = Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.BFLOAT16, 0); - array.putScalar(0, Bfloat16Indexer.toFloat(val)); - if (arrayShape.length > 0) - return array.reshape('c', arrayShape); - else - return array; - } else if (tfTensor.getHalfValCount() > 0) { - //TODO this won't work for huge arrays due to int indexing - int n = tfTensor.getHalfValCount(); - INDArray arr = Nd4j.create(org.nd4j.linalg.api.buffer.DataType.BFLOAT16, n); - ByteBuffer bb = arr.data().pointer().asByteBuffer(); - - for (int e = 0; e < n; e++) { - int val = tfTensor.getHalfVal(e); //FP16 byte returned as int32 bytes (not cast/conversion) with 2 bytes padding :/ - arr.putScalar(e, Bfloat16Indexer.toFloat(val)); - } - - return arr.reshape('c', arrayShape); - } else if (tfTensor.getTensorContent().size() > 0) { - // binary representation - //DataBuffer buffer = Nd4j.createBuffer(tfTensor.getTensorContent().asReadOnlyByteBuffer(), DataType.FLOAT, (int) length); - //INDArray array = Nd4j.createArrayFromShapeBuffer(buffer, Nd4j.getShapeInfoProvider().createShapeInformation(arrayShape, 'c')); - - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - INDArray arr = Nd4j.createUninitialized(org.nd4j.linalg.api.buffer.DataType.BFLOAT16, arrayShape, 'c'); - ByteBuffer bb2 = arr.data().pointer().asByteBuffer(); - bb2.put(bb); - - return arr; - } - } else if (tfTensor.getDtype() == DataType.DT_HALF) { - if (tfTensor.getHalfValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if(tfTensor.getHalfValCount() < 1) - return Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.HALF, 0.0); - - int val = tfTensor.getHalfVal(0); //FP16 byte returned as int32 bytes (not cast/conversion) with 2 bytes padding :/ - INDArray array = Nd4j.scalar(org.nd4j.linalg.api.buffer.DataType.HALF, 0); - //setFloat16ValueFromInt(array, 0, val); - array.putScalar(0, HalfIndexer.toFloat(val)); - if (arrayShape.length > 0) - return array.reshape('c', arrayShape); - else - return array; - } else if (tfTensor.getHalfValCount() > 0) { - //TODO this won't work for huge arrays due to int indexing - int n = tfTensor.getHalfValCount(); - INDArray arr = Nd4j.create(org.nd4j.linalg.api.buffer.DataType.HALF, n); - ByteBuffer bb = arr.data().pointer().asByteBuffer(); - - for (int e = 0; e < n; e++) { - int val = tfTensor.getHalfVal(e); //FP16 byte returned as int32 bytes (not cast/conversion) with 2 bytes padding :/ - //bb.put(2*e, (byte)((val >> 8) & 0xff)); - //bb.put(2*e+1, (byte)(val & 0xff)); - arr.putScalar(e, HalfIndexer.toFloat(val)); - } - - return arr.reshape('c', arrayShape); - } else if (tfTensor.getTensorContent().size() > 0) { - // binary representation - //DataBuffer buffer = Nd4j.createBuffer(tfTensor.getTensorContent().asReadOnlyByteBuffer(), DataType.FLOAT, (int) length); - //INDArray array = Nd4j.createArrayFromShapeBuffer(buffer, Nd4j.getShapeInfoProvider().createShapeInformation(arrayShape, 'c')); - - // binary representation - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - INDArray arr = Nd4j.createUninitialized(org.nd4j.linalg.api.buffer.DataType.HALF, arrayShape, 'c'); - ByteBuffer bb2 = arr.data().pointer().asByteBuffer(); - bb2.put(bb); - - return arr; - } - } else if (tfTensor.getDtype() == DataType.DT_INT64) { - if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if (tfTensor.getInt64ValCount() < 1) - return Nd4j.trueScalar(0.0); - - double val = (double) tfTensor.getInt64Val(0); - INDArray array = Nd4j.trueScalar(val); - if (arrayShape.length > 0) - return array.reshape('c', arrayShape); - else - return array; - } else if (tfTensor.getInt64ValCount() > 0) { - val jArray = new long[tfTensor.getInt64ValCount()]; - for (int e = 0; e < tfTensor.getInt64ValCount(); e++) { - jArray[e] = tfTensor.getInt64Val(e); - } - - // TF arrays are always C - INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.LONG), arrayShape, Nd4j.getStrides(arrayShape, 'c'),0, 'c', org.nd4j.linalg.api.buffer.DataType.LONG); - return array; - } else if (tfTensor.getTensorContent().size() > 0) { - //throw new UnsupportedOperationException("To be implemented yet"); - //Mapping INT bytebuffers should be converted to floating point - val bb = tfTensor.getTensorContent().asReadOnlyByteBuffer(); - val lb = bb.order(ByteOrder.nativeOrder()).asLongBuffer(); - val fa = new long[lb.capacity()]; - for (int e = 0; e < lb.capacity(); e++) - fa[e] = lb.get(e); - if (fa.length == 0) - throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?"); - - if (fa.length == 1) - return Nd4j.scalar(fa[0]); - - if (arrayShape.length == 1) - return Nd4j.createFromArray(fa); - - val array = Nd4j.create(Nd4j.createTypedBuffer(fa, org.nd4j.linalg.api.buffer.DataType.LONG), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', org.nd4j.linalg.api.buffer.DataType.LONG); - return array; - } - } else if (tfTensor.getDtype() == DataType.DT_UINT32) { - //TODO: not sure if tfTensor.getUInt32Val(int) does casting (unlikely) or just uses int (int32) as a storage mechanism - // i.e., the int32 bytes should be interpeted as uint32 bytes, NOT cast - throw new IllegalStateException("Not yet implemented: UINT32"); - } else if (tfTensor.getDtype() == DataType.DT_UINT64) { - //TODO: not sure if tfTensor.getUInt64Val(int) does casting (unlikely) or just uses long (int64) as a storage mechanism - // i.e., the long bytes should be interpeted as uint64 bytes, NOT cast - throw new IllegalStateException("Not yet implemented: UINT64"); - } else if (tfTensor.getDtype() == DataType.DT_BOOL) { - if (tfTensor.getBoolValCount() == 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if (tfTensor.getBoolValCount() < 1) - return Nd4j.scalar(false); - - val val = tfTensor.getBoolVal(0); - val arr = Nd4j.scalar(val); - if (arrayShape.length > 0) - return arr.reshape('c', arrayShape); - else - return arr; - } else if (tfTensor.getBoolValCount() > 0) { - val jArray = new boolean[tfTensor.getBoolValCount()]; - for (int e = 0; e < tfTensor.getBoolValCount(); e++) { - jArray[e] = tfTensor.getBoolVal(e); - } - - // TF arrays are always C - INDArray array = Nd4j.create(Nd4j.createTypedBuffer(jArray, org.nd4j.linalg.api.buffer.DataType.BOOL), arrayShape, Nd4j.getStrides(arrayShape, 'c'), 0, 'c', org.nd4j.linalg.api.buffer.DataType.BOOL); - return array; - } else if (tfTensor.getTensorContent().size() > 0) { - throw new UnsupportedOperationException("Not yet implemented for DataType.DT_BOOL"); - } - } else if(tfTensor.getDtype() == DataType.DT_STRING){ - if (tfTensor.getStringValCount() <= 1 || ArrayUtil.prod(arrayShape) == 1) { - //straight zero case - if (tfTensor.getStringValCount() < 1) - return Nd4j.empty(org.nd4j.linalg.api.buffer.DataType.UTF8); - - String val = tfTensor.getStringVal(0).toStringUtf8(); - INDArray arr = Nd4j.scalar(val); - return arr; - } else if (tfTensor.getStringValCount() > 0) { - String[] sArr = new String[tfTensor.getStringValCount()]; - for (int e = 0; e < sArr.length; e++) { - sArr[e] = tfTensor.getStringVal(e).toStringUtf8(); - } - - // TF arrays are always C - INDArray array = Nd4j.create(sArr).reshape(arrayShape); - return array; - } - } else { - throw new UnsupportedOperationException("Unknown dataType found: [" + tfTensor.getDtype() + "]"); - } - - throw new ND4JIllegalStateException("Invalid method state"); + INDArray out = m.toNDArray(); + return out; } protected static void setFloat16ValueFromInt(INDArray arr, int idx, int bytesAsPaddedInt){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java new file mode 100644 index 000000000..976961656 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java @@ -0,0 +1,41 @@ +package org.nd4j.imports.graphmapper.tf.tensors; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.tensorflow.framework.TensorProto; + +import java.nio.Buffer; +import java.nio.ByteBuffer; + +/** + * @param Java array type + * @param Java buffer type + */ +public interface TFTensorMapper { + + enum ValueSource {EMPTY, VALUE_COUNT, BINARY}; + + DataType dataType(); + + long[] shape(); + + boolean isEmpty(); + + ValueSource valueSource(); + + int valueCount(); + + J newArray(int length); + + B getBuffer(ByteBuffer bb); + + INDArray toNDArray(); + + void getValue(J jArr, int i); + + void getValue(J jArr, B buffer, int i); + + INDArray arrayFor(long[] shape, J jArr); + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java new file mode 100644 index 000000000..722168541 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMappers.java @@ -0,0 +1,726 @@ +package org.nd4j.imports.graphmapper.tf.tensors; + +import com.github.os72.protobuf351.Descriptors; +import org.bytedeco.javacpp.indexer.Bfloat16ArrayIndexer; +import org.bytedeco.javacpp.indexer.HalfIndexer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.util.ArrayUtil; +import org.tensorflow.framework.TensorProto; +import org.tensorflow.framework.TensorShapeProto; + +import java.nio.*; +import java.util.Map; + +public class TFTensorMappers { + + private TFTensorMappers() {} + + + public static TFTensorMapper newMapper(TensorProto tp){ + + switch (tp.getDtype()){ + case DT_HALF: + return new Float16TensorMapper(tp); + case DT_FLOAT: + return new Float32TensorMapper(tp); + case DT_DOUBLE: + return new Float64TensorMapper(tp); + case DT_BFLOAT16: + return new BFloat16TensorMapper(tp); + + case DT_INT8: + return new Int8TensorMapper(tp); + case DT_INT16: + return new Int16TensorMapper(tp); + case DT_INT32: + return new Int32TensorMapper(tp); + case DT_INT64: + return new Int64TensorMapper(tp); + + + case DT_STRING: + return new StringTensorMapper(tp); + + case DT_BOOL: + return new BoolTensorMapper(tp); + + case DT_UINT8: + return new UInt8TensorMapper(tp); + case DT_UINT16: + return new UInt16TensorMapper(tp); + case DT_UINT32: + return new UInt32TensorMapper(tp); + case DT_UINT64: + return new UInt64TensorMapper(tp); + + case DT_QINT8: + case DT_QUINT8: + case DT_QINT32: + case DT_QINT16: + case DT_QUINT16: + throw new IllegalStateException("Unable to map quantized type: " + tp.getDtype()); + case DT_COMPLEX64: + case DT_COMPLEX128: + throw new IllegalStateException("Unable to map complex type: " + tp.getDtype()); + case DT_FLOAT_REF: + case DT_DOUBLE_REF: + case DT_INT32_REF: + case DT_UINT8_REF: + case DT_INT16_REF: + case DT_INT8_REF: + case DT_STRING_REF: + case DT_COMPLEX64_REF: + case DT_INT64_REF: + case DT_BOOL_REF: + case DT_QINT8_REF: + case DT_QUINT8_REF: + case DT_QINT32_REF: + case DT_BFLOAT16_REF: + case DT_QINT16_REF: + case DT_QUINT16_REF: + case DT_UINT16_REF: + case DT_COMPLEX128_REF: + case DT_HALF_REF: + case DT_RESOURCE_REF: + case DT_VARIANT_REF: + case DT_UINT32_REF: + case DT_UINT64_REF: + throw new IllegalStateException("Unable to map reference type: " + tp.getDtype()); + case UNRECOGNIZED: + case DT_RESOURCE: + case DT_VARIANT: + case DT_INVALID: + default: + throw new IllegalStateException("Unable to map type: " + tp.getDtype()); + } + } + + + public static abstract class BaseTensorMapper implements TFTensorMapper { + + protected TensorProto tfTensor; + + public BaseTensorMapper(TensorProto tensorProto){ + this.tfTensor = tensorProto; + } + + @Override + public DataType dataType() { + return ArrayOptionsHelper.convertToDataType(tfTensor.getDtype()); + } + + @Override + public long[] shape() { + int dims = tfTensor.getTensorShape().getDimCount(); + long[] arrayShape = new long[dims]; + for (int e = 0; e < dims; e++) { + arrayShape[e] = tfTensor.getTensorShape().getDim(e).getSize(); + } + return arrayShape; + } + + @Override + public boolean isEmpty() { + return valueSource() == ValueSource.EMPTY; + } + + @Override + public ValueSource valueSource() { + if (valueCount() > 0) { + return ValueSource.VALUE_COUNT; + } + if(tfTensor.getTensorContent() != null && tfTensor.getTensorContent().size() > 0){ + return ValueSource.BINARY; + } + + return ValueSource.EMPTY; + } + + @Override + public INDArray toNDArray() { + DataType dt = dataType(); + ValueSource vs = valueSource(); + long[] shape = shape(); + + INDArray out; + switch (vs){ + case EMPTY: + out = Nd4j.create(dt, shape); + break; + case VALUE_COUNT: + int n = valueCount(); + T array = newArray(n); + for( int i=0; i { + public Float16TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getHalfValCount(); + } + + @Override + public float[] newArray(int length) { + return new float[length]; + } + + @Override + public Buffer getBuffer(ByteBuffer bb) { + throw new UnsupportedOperationException("Not yet implemnted: FP16 reading from buffer"); + } + + @Override + public void getValue(float[] jArr, int i) { + int asIntBytes = tfTensor.getHalfVal(i); + jArr[i] = HalfIndexer.toFloat(asIntBytes); + } + + @Override + public void getValue(float[] jArr, Buffer buffer, int i){ + throw new UnsupportedOperationException("Not yet implemnted: FP16 reading from buffer"); + } + + @Override + public INDArray arrayFor(long[] shape, float[] jArr) { + //Edge case: sometimes tf has single float value for entire array (getFloatValCount() == 1) + if(jArr.length == 1 && ArrayUtil.prod(shape) > 1) + return Nd4j.createUninitialized(DataType.HALF, shape).assign(jArr[0]); + return Nd4j.create(jArr, shape, 'c').castTo(DataType.HALF); + } + } + + public static class Float32TensorMapper extends BaseTensorMapper { + public Float32TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getFloatValCount(); + } + + @Override + public float[] newArray(int length) { + return new float[length]; + } + + @Override + public FloatBuffer getBuffer(ByteBuffer bb) { + return bb.asFloatBuffer(); + } + + @Override + public void getValue(float[] jArr, int i) { + jArr[i] = tfTensor.getFloatVal(i); + } + + @Override + public void getValue(float[] jArr, FloatBuffer buffer, int i){ + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, float[] jArr) { + //Edge case: sometimes tf has single float value for entire array (getFloatValCount() == 1) + if(jArr.length == 1 && ArrayUtil.prod(shape) > 1) + return Nd4j.valueArrayOf(shape, jArr[0]); + return Nd4j.create(jArr, shape, 'c'); + } + } + + public static class Float64TensorMapper extends BaseTensorMapper { + public Float64TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getDoubleValCount(); + } + + @Override + public double[] newArray(int length) { + return new double[length]; + } + + @Override + public DoubleBuffer getBuffer(ByteBuffer bb) { + return bb.asDoubleBuffer(); + } + + @Override + public void getValue(double[] jArr, int i) { + jArr[i] = tfTensor.getDoubleVal(i); + } + + @Override + public void getValue(double[] jArr, DoubleBuffer buffer, int i) { + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, double[] jArr) { + //Edge case: sometimes tf has double float value for entire array (getDoubleValCount() == 1) + if(jArr.length == 1 && ArrayUtil.prod(shape) > 1) + return Nd4j.valueArrayOf(shape, jArr[0]); + return Nd4j.create(jArr, shape, 'c'); + } + } + + public static class BFloat16TensorMapper extends BaseTensorMapper { + public BFloat16TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getHalfValCount(); + } + + @Override + public float[] newArray(int length) { + return new float[length]; + } + + @Override + public ShortBuffer getBuffer(ByteBuffer bb) { + return bb.asShortBuffer(); + } + + @Override + public void getValue(float[] jArr, int i) { + int asIntBytes = tfTensor.getHalfVal(i); + jArr[i] = Bfloat16ArrayIndexer.toFloat(asIntBytes); + } + + @Override + public void getValue(float[] jArr, ShortBuffer buffer, int i){ + throw new UnsupportedOperationException("Not yet implemnted: BFP16 reading from buffer"); + } + + @Override + public INDArray arrayFor(long[] shape, float[] jArr) { + //Edge case: sometimes tf has single float value for entire array (getFloatValCount() == 1) + if(jArr.length == 1 && ArrayUtil.prod(shape) > 1) + return Nd4j.createUninitialized(DataType.HALF, shape).assign(jArr[0]); + return Nd4j.create(jArr, shape, 'c').castTo(DataType.BFLOAT16); + } + } + + //Note TF stortes bytes as integer (other than when in a biffer) + public static class Int8TensorMapper extends BaseTensorMapper { + + public Int8TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //int8 as integer + return tfTensor.getIntValCount(); + } + + @Override + public int[] newArray(int length) { + return new int[length]; + } + + @Override + public ByteBuffer getBuffer(ByteBuffer bb) { + return bb; + } + + @Override + public void getValue(int[] jArr, int i) { + jArr[i] = tfTensor.getIntVal(i); + } + + @Override + public void getValue(int[] jArr, ByteBuffer buffer, int i) { + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, int[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + public static class Int16TensorMapper extends BaseTensorMapper { + + public Int16TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //Shorts as integer + return tfTensor.getIntValCount(); + } + + @Override + public int[] newArray(int length) { + return new int[length]; + } + + @Override + public ShortBuffer getBuffer(ByteBuffer bb) { + return bb.asShortBuffer(); + } + + @Override + public void getValue(int[] jArr, int i) { + jArr[i] = tfTensor.getIntVal(i); + } + + @Override + public void getValue(int[] jArr, ShortBuffer buffer, int i) { + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, int[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + + public static class Int32TensorMapper extends BaseTensorMapper { + + public Int32TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getIntValCount(); + } + + @Override + public int[] newArray(int length) { + return new int[length]; + } + + @Override + public IntBuffer getBuffer(ByteBuffer bb) { + return bb.asIntBuffer(); + } + + @Override + public void getValue(int[] jArr, int i) { + jArr[i] = tfTensor.getIntVal(i); + } + + @Override + public void getValue(int[] jArr, IntBuffer buffer, int i) { + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, int[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + public static class Int64TensorMapper extends BaseTensorMapper { + + public Int64TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getInt64ValCount(); + } + + @Override + public long[] newArray(int length) { + return new long[length]; + } + + @Override + public LongBuffer getBuffer(ByteBuffer bb) { + return bb.asLongBuffer(); + } + + @Override + public void getValue(long[] jArr, int i) { + jArr[i] = tfTensor.getInt64Val(i); + } + + @Override + public void getValue(long[] jArr, LongBuffer buffer, int i) { + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, long[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + //Note TF stortes bytes as integer (other than when in a buffer) + public static class UInt8TensorMapper extends BaseTensorMapper { + + public UInt8TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //int8 as integer + return tfTensor.getIntValCount(); + } + + @Override + public int[] newArray(int length) { + return new int[length]; + } + + @Override + public ByteBuffer getBuffer(ByteBuffer bb) { + return bb; + } + + @Override + public void getValue(int[] jArr, int i) { + jArr[i] = tfTensor.getIntVal(i); + } + + @Override + public void getValue(int[] jArr, ByteBuffer buffer, int i) { + byte b = buffer.get(i); //Signed, but bytes are really for unsigned... + jArr[i] = b & 0xff; + } + + @Override + public INDArray arrayFor(long[] shape, int[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + public static class UInt16TensorMapper extends BaseTensorMapper { + + public UInt16TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //int8 as integer + return tfTensor.getIntValCount(); + } + + @Override + public int[] newArray(int length) { + return new int[length]; + } + + @Override + public ShortBuffer getBuffer(ByteBuffer bb) { + return bb.asShortBuffer(); + } + + @Override + public void getValue(int[] jArr, int i) { + jArr[i] = tfTensor.getIntVal(i); + } + + @Override + public void getValue(int[] jArr, ShortBuffer buffer, int i) { + short b = buffer.get(i); //Signed, but bytes are really for unsigned... + jArr[i] = b & 0xffff; + } + + @Override + public INDArray arrayFor(long[] shape, int[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + public static class UInt32TensorMapper extends BaseTensorMapper { + + public UInt32TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //int8 as integer + return tfTensor.getInt64ValCount(); + } + + @Override + public long[] newArray(int length) { + return new long[length]; + } + + @Override + public IntBuffer getBuffer(ByteBuffer bb) { + return bb.asIntBuffer(); + } + + @Override + public void getValue(long[] jArr, int i) { + jArr[i] = tfTensor.getInt64Val(i); + } + + @Override + public void getValue(long[] jArr, IntBuffer buffer, int i) { + int b = buffer.get(i); //Signed, but bytes are really for unsigned... + jArr[i] = b & 0xffffffffL; + } + + @Override + public INDArray arrayFor(long[] shape, long[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + public static class UInt64TensorMapper extends BaseTensorMapper { + + public UInt64TensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + //int8 as integer + return tfTensor.getInt64ValCount(); + } + + @Override + public long[] newArray(int length) { + return new long[length]; + } + + @Override + public LongBuffer getBuffer(ByteBuffer bb) { + return bb.asLongBuffer(); + } + + @Override + public void getValue(long[] jArr, int i) { + //TODO out of range for largest values! + jArr[i] = tfTensor.getInt64Val(i); + } + + @Override + public void getValue(long[] jArr, LongBuffer buffer, int i) { + //TODO out of range for largest values! + jArr[i] = buffer.get(i); + } + + @Override + public INDArray arrayFor(long[] shape, long[] jArr) { + DataType dt = dataType(); + return Nd4j.create(Nd4j.createTypedBuffer(jArr, dt), shape,Nd4j.getStrides(shape, 'c'), 0, 'c', dt); + } + } + + + public static class StringTensorMapper extends BaseTensorMapper { + public StringTensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getStringValCount(); + } + + @Override + public String[] newArray(int length) { + return new String[length]; + } + + @Override + public ByteBuffer getBuffer(ByteBuffer bb) { + throw new UnsupportedOperationException("Not supported for String types"); + } + + @Override + public void getValue(String[] jArr, int i) { + jArr[i] = tfTensor.getStringVal(i).toStringUtf8(); + } + + @Override + public void getValue(String[] jArr, ByteBuffer buffer, int i) { + throw new UnsupportedOperationException("Not supported for String types"); + } + + @Override + public INDArray arrayFor(long[] shape, String[] jArr) { + return Nd4j.create(jArr).reshape(shape); + } + } + + public static class BoolTensorMapper extends BaseTensorMapper { + public BoolTensorMapper(TensorProto tensorProto) { + super(tensorProto); + } + + @Override + public int valueCount() { + return tfTensor.getBoolValCount(); + } + + @Override + public boolean[] newArray(int length) { + return new boolean[length]; + } + + @Override + public ByteBuffer getBuffer(ByteBuffer bb) { + throw new UnsupportedOperationException("Not supported for String types"); + } + + @Override + public void getValue(boolean[] jArr, int i) { + jArr[i] = tfTensor.getBoolVal(i); + } + + @Override + public void getValue(boolean[] jArr, ByteBuffer buffer, int i) { + throw new UnsupportedOperationException("Not supported for boolean types"); + } + + @Override + public INDArray arrayFor(long[] shape, boolean[] jArr) { + return Nd4j.create(jArr).reshape(shape); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index ea72f4c71..b3e72bdde 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -995,6 +995,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { if (dimension == null || dimension.length == 0) throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); + Preconditions.checkArgument(!this.isEmpty(), "tensorAlongDimension(...) can't be used on empty tensors"); + if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) return this; for (int i = 0; i < dimension.length; i++) @@ -2170,13 +2172,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { /** * Returns true if this ndarray is 2d - * or 3d with a singleton element * * @return true if the element is a matrix, false otherwise */ public boolean isMatrix() { - int rank = rank(); - return (rank == 2 && (size(0) != 1 && size(1) != 1)); + return rank() == 2; } protected INDArray newShape(long[] newShape, char ordering) { @@ -4669,7 +4669,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray min(int... dimension) { - validateNumericalArray("max", false); + validateNumericalArray("min", false); return Nd4j.getExecutioner().exec(new Min(this, dimension)); } @@ -4687,7 +4687,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray sum(int... dimension) { - validateNumericalArray("sum", false); + validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, dimension)); } @@ -4699,7 +4699,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { */ @Override public INDArray sum(boolean keepDim, int... dimension) { - validateNumericalArray("sum", false); + validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); } @@ -4739,7 +4739,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray sum(@NonNull INDArray result, int... dimension) { - validateNumericalArray("sum", false); + validateNumericalArray("sum", true); return Nd4j.getExecutioner().exec(new Sum(this, result, dimension)); } @@ -5219,7 +5219,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { return false; if (this.isEmpty() && n.isEmpty()) - return true; + return Shape.shapeEquals(this.shape(), n.shape()); if (this.dataType() != n.dataType()) return false; @@ -6384,6 +6384,18 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray median(int... dimension) { validateNumericalArray("median", false); + //Check edge case: size 1 element. No dimension == full array + if(dimension.length == 0){ + return Nd4j.scalar(dataType(), medianNumber().doubleValue()); + } + long shapeProd = 1; + for (int d : dimension) { + shapeProd *= size(d); + } + if (shapeProd == 1) { + long[] newShape = ArrayUtil.removeIndex(shape(), dimension); + return dup('c').reshape('c', newShape); + } return percentile(50, dimension); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index b0a75d099..6e2801c67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -89,7 +89,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); + long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.BOOL)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java index 8be76164b..fbebe0c9e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceFloatOp.java @@ -114,7 +114,7 @@ public abstract class BaseReduceFloatOp extends BaseReduceOp implements ReduceFl return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); + long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); DataType retType = arg().dataType(); if(!retType.isFPType()) retType = Nd4j.defaultFloatingPointType(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java index c37bf9986..a85457c37 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceLongOp.java @@ -90,7 +90,7 @@ public abstract class BaseReduceLongOp extends BaseReduceOp implements ReduceLon return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); + long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, DataType.LONG)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 2f123e167..039e91b53 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -54,6 +54,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { @Setter @Getter protected boolean keepDims = false; protected boolean isComplex = false; + @Setter @Getter + protected boolean isEmptyReduce = false; public BaseReduceOp(SameDiff sameDiff, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java index 5fd71bf94..ab6c8b377 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceSameOp.java @@ -96,7 +96,7 @@ public abstract class BaseReduceSameOp extends BaseReduceOp implements ReduceSam return Collections.emptyList(); //Calculate reduction shape. Note that reduction on scalar - returns a scalar - long[] reducedShape = x.length() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); + long[] reducedShape = x.rank() == 0 ? x.shape() : Shape.getReducedShape(x.shape(),dimensions, isKeepDims()); return Collections.singletonList(LongShapeDescriptor.fromShape(reducedShape, this.resultType())); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastTo.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastTo.java index a94064a13..8f26bcd43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastTo.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastTo.java @@ -17,12 +17,16 @@ package org.nd4j.linalg.api.ops.impl.broadcast; import lombok.NoArgsConstructor; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Broadcast; +import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -45,6 +49,14 @@ public class BroadcastTo extends DynamicCustomOp { super(null, sameDiff, new SDVariable[] {input,shape}, false); } + public BroadcastTo(@NonNull INDArray input, @NonNull long[] shape, @NonNull INDArray output){ + this(input, Nd4j.createFromArray(shape), output); + } + + public BroadcastTo(@NonNull INDArray input, @NonNull INDArray shape, @NonNull INDArray output){ + super(null, new INDArray[]{input, shape}, new INDArray[]{output}); + } + @Override public String opName() { return "broadcast_to"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 477ecf9b6..ff29189cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -18,9 +18,10 @@ package org.nd4j.linalg.api.ops.impl.reduce.custom; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseReduceFloatOp; -import org.nd4j.linalg.api.ops.BaseReduceOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; import java.util.List; @@ -30,36 +31,46 @@ import java.util.List; * * @author raver119@gmail.com */ -public class LogSumExp extends BaseReduceFloatOp { - public LogSumExp(SameDiff sameDiff, SDVariable i_v, int[] dimensions) { - super(sameDiff, i_v, dimensions); - } +public class LogSumExp extends DynamicCustomOp { - public LogSumExp(SameDiff sameDiff, SDVariable i_v, SDVariable i_v2, int[] dimensions) { - super(sameDiff, i_v, i_v2, dimensions); + protected boolean keepDims; + + public LogSumExp(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { + super(sameDiff, i_v); + if(dimensions != null) { + addIArgument(dimensions); + } + addTArgument(keepDims ? 1.0 : 0.0); + this.keepDims = keepDims; } public LogSumExp() {} public LogSumExp(INDArray x, int... dimensions) { - super(x, dimensions); + this(x, false, dimensions); } - public LogSumExp(INDArray x, INDArray z, int... dimensions) { - super(x, z, dimensions); + public LogSumExp(INDArray x, boolean keepDim, int... dimensions) { + this(x, null, keepDim, dimensions); } - @Override - public int opNum() { - return 11; + public LogSumExp(INDArray x, INDArray z, boolean keepDim, int... dimensions) { + super(null, x,z, Collections.singletonList(keepDim ? 1.0 : 0.0), dimensions); } @Override public String opName() { - return "logexpsum"; + return "reduce_logsumexp"; } + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), + "Expected 1 or 2 input datatypes for %s, got %s", getClass(), dataTypes); + return Collections.singletonList(dataTypes.get(0)); + } + @Override public List doDiff(List f1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java index a8bc01233..1782f75df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Gather.java @@ -150,32 +150,33 @@ public class Gather extends DynamicCustomOp { SDVariable indicesGrad = sameDiff.zerosLike(arg(1)); SDVariable inputGrad = sameDiff.zerosLike(arg(0)); - int ndim = arg(0).getShape().length; - int a = jaxis; - if(a < 0){ - a += ndim; + SDVariable[] inputs = args(); + SDVariable axis; + SDVariable rank = inputs[0].rank(); + if(inputs.length == 2){ + axis = sameDiff.constant(jaxis); + if(jaxis < 0) + axis = axis.add(rank); + } else { + axis = inputs[2]; } - if(a == 0){ - inputGrad = sameDiff.scatterAdd(inputGrad, arg(1), i_v.get(0)); - } else { - int[] permDims = new int[ndim]; - permDims[0] = a; - for(int i=0; i doDiff(List i_v) { - SDVariable ret = f().permute(i_v.get(0), reverseDims); - return Arrays.asList(ret); + SDVariable ret; + if(args().length == 1) { + //Static dimensions + ret = f().permute(i_v.get(0), reverseDims); + } else { + //Dynamic dimensions + ret = f().permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); + } + return Collections.singletonList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java index 972675891..965d071c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Transpose.java @@ -52,7 +52,10 @@ public class Transpose extends DynamicCustomOp { public Transpose(SameDiff sameDiff, SDVariable in, int[] permuteDims){ super(null, sameDiff, new SDVariable[]{in}); this.permuteDims = permuteDims; + } + protected Transpose(SameDiff sameDiff, SDVariable in, SDVariable permuteDims){ + super(null, sameDiff, new SDVariable[]{in, permuteDims}); } public Transpose(INDArray input, INDArray result){ @@ -160,34 +163,6 @@ public class Transpose extends DynamicCustomOp { this.permuteDims = Ints.toArray(attributesForNode.get("perm").getIntsList()); } - @Override - public List calculateOutputShape() { - if(numInputArguments() > 1){ - return super.calculateOutputShape(); - } else if (args().length > 1) { - if (args()[0].getArr() != null && args()[1].getArr() != null) { - return super.calculateOutputShape(); - } - } else if (permuteDims == null && arg() != null && arg().getShape() != null) { - this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arg().getShape().length)); - val permutedShape = ArrayUtil.permute(arg().getShape(), permuteDims); - return Arrays.asList(LongShapeDescriptor.fromShape(permutedShape, larg().dataType())); - } else if (permuteDims != null && arg() != null && (!inputArguments.isEmpty() || arg().getShape() != null)) { - long[] shape = null; - if(!inputArguments.isEmpty()) - shape = inputArguments.get(0).shape(); - else - shape = arg().getShape(); - val permutedShape = ArrayUtil.permute(shape, permuteDims); - SDVariable lArg = larg(); - DataType lArgType = lArg.dataType(); - return Arrays.asList(LongShapeDescriptor.fromShape(permutedShape, lArgType)); - } - - return Collections.emptyList(); - } - - @Override public List doDiff(List i_v) { SDVariable ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index b2691cb04..db95ee728 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -22,19 +22,15 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; -import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; -import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -79,21 +75,8 @@ public class Fill extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - if(nodeDef.getInputCount() == 2) { - val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph,nodeDef.getInput(1)); - val mapper = TFGraphMapper.getInstance(); - val secondInputAsScalar = mapper.getNDArrayFromTensor("value",targetNode,graph); - //must be scalar - if(secondInputAsScalar.length() == 1) { - addTArgument(secondInputAsScalar.getDouble(0)); - } - else { - throw new ND4JIllegalStateException("Second input to node " + nodeDef + " should be scalar!"); - } - - org.tensorflow.framework.DataType dt = attributesForNode.get("T").getType(); - this.outputDataType = DataTypeAdapter.dtypeConv(dt); - } + org.tensorflow.framework.DataType dt = attributesForNode.get("T").getType(); + this.outputDataType = DataTypeAdapter.dtypeConv(dt); } @Override @@ -119,50 +102,6 @@ public class Fill extends DynamicCustomOp { } - - @Override - public List calculateOutputShape() { - - - INDArray shape; - DataType dt = outputDataType; - if(sameDiff != null ) { - - int numArgs = args().length; - if (numArgs < 1) - return Collections.emptyList(); - - SDVariable[] args = args(); - shape = !inputArguments.isEmpty() ? inputArguments.get(0) : args()[0].getArr(); - if(args.length > 1){ - dt = arg(1).getArr().dataType(); - } else if(inputArguments.size() > 1){ - dt = inputArguments.get(1).dataType(); - } - } else { - if(numInputArguments() == 0) { - shape = null; - } else { - shape = inputArguments.get(0); - if(inputArguments.size() > 1){ - dt = inputArguments.get(1).dataType(); - } - } - } - if(shape == null) - return Collections.emptyList(); - else { - //TODO properly allow customizing datatype - if(shape.isEmpty() || (shape.length() > 0 && shape.minNumber().intValue() == 0)){ - //Empty shape: Edge case, mainly for TF import - //Also 'shape with zero' are empty arrays in TF - return Collections.singletonList(LongShapeDescriptor.empty(dt)); - } else { - return Arrays.asList(LongShapeDescriptor.fromShape(shape.data().asLong(), dt)); - } - } - } - @Override public String opName() { return "fill"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java index f0ae43d6e..ff63a07b7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ListDiff.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.custom; import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -30,6 +31,10 @@ public class ListDiff extends DynamicCustomOp { // } + public ListDiff(SameDiff sd, SDVariable x, SDVariable y){ + super(sd, new SDVariable[]{x, y}); + } + @Override public String tensorflowName() { return "ListDiff"; //Note: Seems to be renamed to tf.setdiff1d in public API? diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java index c93091e6a..27e9d9f3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/impl/Range.java @@ -71,6 +71,11 @@ public class Range extends DynamicCustomOp { this.dataType = dataType; } + public Range(SameDiff sd, SDVariable from, SDVariable to, SDVariable step, DataType dataType){ + super(null, sd, new SDVariable[]{from, to, step}); + this.dataType = dataType; + } + @Override public int opNum() { @@ -102,25 +107,6 @@ public class Range extends DynamicCustomOp { } } - @Override - public List calculateOutputShape() { - val iArgs = iArgs(); - val tArgs = tArgs(); - val inputArgs = inputArguments(); - int cnt = 0; - - if(sameDiff != null && args().length > 1) { - if (inputArgs.length > 0) - return Nd4j.getExecutioner().calculateOutputShape(this); - } else if (iArgs.length > 0) { - return Nd4j.getExecutioner().calculateOutputShape(this); - } else if (tArgs.length > 0) { - return Nd4j.getExecutioner().calculateOutputShape(this); - } - - return Collections.emptyList(); - } - @Override public List doDiff(List f1) { return Collections.emptyList(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java index 710c26461..b15da1ad5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java @@ -96,6 +96,10 @@ public class LongShapeDescriptor { } + public int rank(){ + return shape == null ? 0 : shape.length; + } + public DataType dataType() { return ArrayOptionsHelper.dataType(extras); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index de9f3624d..e561b2257 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -243,8 +243,20 @@ public class Shape { return Ints.toArray(dims); } + public static boolean containsZeros(long[] shapeOnly) { + for (val v:shapeOnly) + if (v == 0) + return true; + + return false; + } public static long[] broadcastOutputShape(long[] left,long[] right) { + if (containsZeros(left)) + return left; + else if (containsZeros(right)) + return right; + assertBroadcastable(left, right); if(Arrays.equals(left,right)) return left; @@ -3201,7 +3213,16 @@ public class Shape { } public static DataBuffer createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType, boolean empty) { - return Nd4j.getExecutioner().createShapeInfo(shape, stride, elementWiseStride, order, dataType, empty); + boolean isEmpty = empty; + if (!empty) + for (val v:shape) { + if (v == 0) { + isEmpty = true; + break; + } + } + + return Nd4j.getExecutioner().createShapeInfo(shape, stride, elementWiseStride, order, dataType, isEmpty); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 7821b9376..c664fc479 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -1310,6 +1310,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { val b = value.byteValue(); val arr = create(new byte[] {b}, new long[] {}, new long[] {}, dataType, ws); return arr; + default: throw new UnsupportedOperationException("Unsupported data type used: " + dataType); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index cf6c2158a..431f24496 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -606,7 +606,7 @@ public class Nd4j { } public static INDArray create(LongShapeDescriptor descriptor, boolean initialize) { - if(descriptor.isEmpty()){ + if(descriptor.isEmpty() && descriptor.rank() == 0){ return Nd4j.empty(descriptor.dataType()); } if (initialize) @@ -1819,7 +1819,8 @@ public class Nd4j { INDArray indices = Nd4j.create(ndarray.shape()); INDArray[] ret = new INDArray[2]; - for (int i = 0; i < ndarray.vectorsAlongDimension(dimension); i++) { + long nV = ndarray.vectorsAlongDimension(dimension); + for (int i = 0; i < nV; i++) { INDArray vec = ndarray.vectorAlongDimension(i, dimension); INDArray indexVector = indices.vectorAlongDimension(i, dimension); final Double[] data = new Double[(int) vec.length()]; @@ -2867,12 +2868,12 @@ public class Nd4j { */ public static INDArray diag(INDArray x, int k) { INDArray ret; - if(x.isMatrix()) { - ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))}); - Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); - } else { + if(x.isVectorOrScalar() || x.isRowVector() || x.isColumnVector()) { ret = Nd4j.create(new long[]{x.length(), x.length()}); Nd4j.getExecutioner().execAndReturn(new Diag(new INDArray[]{x},new INDArray[]{ret})); + } else { + ret = Nd4j.createUninitialized(new long[]{Math.min(x.size(0), x.size(1))}); + Nd4j.getExecutioner().execAndReturn(new DiagPart(x,ret)); } return ret; } @@ -3646,9 +3647,6 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(int columns, char order) { - if (columns < 1) - throw new ND4JIllegalStateException("Number of columns should be positive for new INDArray"); - INDArray ret = INSTANCE.create(new long[] {columns}, Nd4j.getStrides(new long[] {columns}, order), 0, order); return ret; } @@ -4664,8 +4662,6 @@ public class Nd4j { if(shape.length == 0) return Nd4j.scalar(dataType(), 0.0); - checkShapeValues(shape); - INDArray ret = INSTANCE.create(shape, ordering); return ret; } @@ -4715,9 +4711,9 @@ public class Nd4j { */ public static void checkShapeValues(long[] shape) { for (long e: shape) { - if (e < 1) + if (e < 0) throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) - + " contains dimension size values < 1 (all dimensions must be 1 or more)"); + + " contains dimension size values < 0 (all dimensions must be 0 or more)"); } } @@ -4729,7 +4725,7 @@ public class Nd4j { for (int e: shape) { if (e < 1) throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) - + " contains dimension size values < 1 (all dimensions must be 1 or more)"); + + " contains dimension size values < 0 (all dimensions must be 0 or more)"); } } @@ -4883,9 +4879,6 @@ public class Nd4j { } public static INDArray createUninitialized(long length) { - if (length < 1) - throw new IllegalStateException("INDArray length should be positive value"); - long[] shape = new long[] {length}; INDArray ret = INSTANCE.createUninitialized(shape, order()); @@ -4899,9 +4892,6 @@ public class Nd4j { * @return */ public static INDArray createUninitializedDetached(int length) { - if (length < 1) - throw new IllegalStateException("INDArray length should be positive value"); - long[] shape = new long[] {length}; INDArray ret = INSTANCE.createUninitializedDetached(shape, order()); @@ -5205,9 +5195,6 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long rows, long columns, double value) { - if (rows < 1 || columns < 1) - throw new ND4JIllegalStateException("Number of rows and columns should be positive for new INDArray"); - INDArray ret = INSTANCE.valueArrayOf(rows, columns, value); return ret; } @@ -5618,8 +5605,6 @@ public class Nd4j { public static INDArray ones(@NonNull int... shape) { if(shape.length == 0) return Nd4j.scalar(dataType(), 1.0); - checkShapeValues(shape); - INDArray ret = INSTANCE.ones(shape); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java index a9e9670ce..a38a4a198 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/deallocation/DeallocatorService.java @@ -23,6 +23,7 @@ import org.apache.commons.lang3.RandomUtils; import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.factory.Nd4j; + import java.lang.ref.ReferenceQueue; import java.util.ArrayList; import java.util.HashMap; @@ -85,10 +86,13 @@ public class DeallocatorService { private class DeallocatorServiceThread extends Thread implements Runnable { private final ReferenceQueue queue; private final int threadIdx; + public static final String DeallocatorThreadNamePrefix = "DeallocatorServiceThread thread "; private DeallocatorServiceThread(@NonNull ReferenceQueue queue, int threadIdx) { this.queue = queue; this.threadIdx = threadIdx; + this.setName(DeallocatorThreadNamePrefix + threadIdx); + setContextClassLoader(null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index dfb55f1eb..440b096c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.performance.PerformanceTracker; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.factory.BaseNDArrayFactory; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.memory.MemcpyDirection; @@ -138,34 +139,140 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { dataPointer.limit(dataBufferElementSize * Shape.length(shapeBuffer)); dataPointer.capacity(dataBufferElementSize * Shape.length(shapeBuffer)); + val jvmShapeInfo = shapeBuffer.asLong(); + log.info("JVM shapeInfo: {}", jvmShapeInfo); + val dtype = ArrayOptionsHelper.dataType(jvmShapeInfo); - if(dataBufferElementSize == (Float.SIZE / 8)) { - FloatPointer dPointer = new FloatPointer(dataPointer.limit() / dataBufferElementSize); + switch (dtype) { + case UBYTE: { + val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); - Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + UByteIndexer.create(dPointer)); + } + break; + case BYTE: { + val dPointer = new BytePointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - data = Nd4j.createBuffer(dPointer, - DataType.FLOAT, - Shape.length(shapeBuffer), - FloatIndexer.create(dPointer)); - } - else if(dataBufferElementSize == (Double.SIZE / 8)) { - DoublePointer dPointer = new DoublePointer(dataPointer.limit() / dataBufferElementSize); + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); - Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + ByteIndexer.create(dPointer)); + } + break; + case UINT64: + case LONG: { + val dPointer = new LongPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); - data = Nd4j.createBuffer(dPointer, - DataType.DOUBLE, - Shape.length(shapeBuffer), - DoubleIndexer.create(dPointer)); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + LongIndexer.create(dPointer)); + } + break; + case UINT32: + case INT: { + val dPointer = new IntPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + IntIndexer.create(dPointer)); + } + break; + case UINT16: { + val dPointer = new ShortPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + UShortIndexer.create(dPointer)); + } + break; + case SHORT: { + val dPointer = new ShortPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + ShortIndexer.create(dPointer)); + } + break; + case BFLOAT16: + case HALF: { + val dPointer = new ShortPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + HalfIndexer.create(dPointer)); + } + break; + case FLOAT: { + val dPointer = new FloatPointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + FloatIndexer.create(dPointer)); + } + break; + case DOUBLE: { + val dPointer = new DoublePointer(dataPointer.limit() / dataBufferElementSize); + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + Pointer.memcpy(dPointer, dataPointer, dataPointer.limit()); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, dataPointer.limit(), MemcpyDirection.HOST_TO_HOST); + + data = Nd4j.createBuffer(dPointer, + dtype, + Shape.length(shapeBuffer), + DoubleIndexer.create(dPointer)); + } + break; } INDArray ret = Nd4j.create(data, diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index a9db9c52f..aaaee1d2e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -171,6 +171,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + if (op.x().isEmpty()) { + for (val d:dimension) { + Preconditions.checkArgument(op.x().shape()[d] != 0, "IndexReduce can't be issued along axis with 0 in shape"); + } + } + boolean keepDims = op.isKeepDims(); long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims); @@ -236,6 +242,21 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { public INDArray exec(ReduceOp op) { Preconditions.checkNotNull(op.x(), "Op.x() cannot be null: Was null for op %s", op); op.validateDataTypes(); + + if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){ + //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y] + //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions" + if(op.z() != null){ + Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." + + " Got: x=%ndShape, z=%ndShape", op.x(), op.z()); + op.z().assign(op.x()); + return op.z(); + } else { + op.setZ(op.x().dup()); + return op.z(); + } + } + val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); //validateDataType(Nd4j.dataType(), op); @@ -246,6 +267,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { boolean keepDims = op.isKeepDims(); long[] retShape = Shape.reductionShape(op.x(), dimension, true, keepDims); + if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null) return op.noOp(); @@ -264,7 +286,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (op.y() != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y - if(op.x().lengthLong() == op.y().lengthLong()) { + if(op.x().length() == op.y().length()) { //Pairwise if (op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) { throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + @@ -289,8 +311,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } else { // compare length long shapeProduct = (retShape.length == 0 ? 1 : ArrayUtil.prodLong(retShape)); - if (!op.isComplexAccumulation() && op.z().lengthLong() != shapeProduct) - throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); + if (!op.isComplexAccumulation() && op.z().length() != shapeProduct) { + if(!(op.x().isEmpty() && op.isKeepDims())){ + //Empty reductions are special case: [1,0].sum(0,1,keep=true) -> shape [1,1] + throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]"); + } + } else if (op.isComplexAccumulation()) { long xT = op.x().tensorsAlongDimension(dimension); long yT = op.y().tensorsAlongDimension(dimension); @@ -309,16 +335,16 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. */ - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension); + Pair tadBuffers = op.x().isEmpty() ? Pair.makePair(op.x().data(), null): tadManager.getTADOnlyShapeInfo(op.x(), dimension); Pair yTadBuffers = null; /** * Note that we use addresses in libnd4j. * We use reinterpret cast in c to take the long * we pass to JNI. This manages overhead. */ - Pointer hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); + Pointer hostTadShapeInfo = op.x().isEmpty() ? op.x().shapeInfoDataBuffer().addressPointer() : tadBuffers.getFirst().addressPointer(); - DataBuffer offsets = tadBuffers.getSecond(); + DataBuffer offsets = op.x().isEmpty() ? null : tadBuffers.getSecond(); Pointer hostTadOffsets = offsets == null ? null : offsets.addressPointer(); // we're going to check, if that's TAD vs TAD comparison or TAD vs full array. if later - we're going slightly different route diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index e62428284..ec11d6c23 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1996,24 +1996,21 @@ public class ShapeOpValidation extends BaseOpValidation { public void testGatherEmpty(){ /* tf.reset_default_graph() - # Hack to create empty array - input = tf.constant([False], dtype=tf.bool) - empty = tf.where(condition=input) - emptyInt = tf.cast(empty, tf.int32) + emptyInt = tf.constant([], shape=[0], dtype=tf.int32) ingather = tf.reshape(tf.range(start=0,limit=100,delta=1,dtype=tf.float32), [25,4]) gather = tf.gather(params=ingather, indices=emptyInt) sess = tf.Session() out = sess.run([gather]) print(out[0].shape); print(out[0]); - >> (0, 1, 4) + >> (0, 4) >> [] */ - Nd4j.getExecutioner().enableVerboseMode(true); - Nd4j.getExecutioner().enableDebugMode(true); +// Nd4j.getExecutioner().enableVerboseMode(true); +// Nd4j.getExecutioner().enableDebugMode(true); - INDArray emptyInt = Nd4j.empty(DataType.INT); + INDArray emptyInt = Nd4j.create(DataType.INT, 0); INDArray inGather = Nd4j.linspace(1,100,100,DataType.FLOAT).reshape(25,4); DynamicCustomOp op = DynamicCustomOp.builder("gather") @@ -2022,9 +2019,9 @@ public class ShapeOpValidation extends BaseOpValidation { List l = op.calculateOutputShape(); long[] shape = l.get(0).getShape(); + assertArrayEquals(new long[]{0,4}, l.get(0).getShape()); boolean isEmpty = l.get(0).isEmpty(); assertTrue(isEmpty); - assertArrayEquals(new long[0], shape); } @Test @@ -2071,11 +2068,9 @@ public class ShapeOpValidation extends BaseOpValidation { cotcat(empty,nonEmpty) -> nonEmpty, etc (i.e., empty arrays are ignored) tf.reset_default_graph() - # Hack to create empty array input = tf.constant([False], dtype=tf.bool) - empty = tf.where(condition=input) - emptyFloat = tf.cast(empty, tf.float32) - var11 = tf.reshape(tf.constant([1], dtype=tf.float32), shape=[1,1]) + emptyFloat = tf.constant([], shape=[0,1], dtype=tf.float32) + var11 = tf.constant([1], dtype=tf.float32, shape=[1,1]) concat = tf.concat(values=[emptyFloat, emptyFloat, var11, emptyFloat], axis=0) @@ -2085,18 +2080,32 @@ public class ShapeOpValidation extends BaseOpValidation { print(out[0]); */ - INDArray empty = Nd4j.empty(DataType.FLOAT); + INDArray one1 = Nd4j.create(DataType.FLOAT, 1, 1); + INDArray empty01 = Nd4j.create(DataType.FLOAT, 0, 1); DynamicCustomOp op = DynamicCustomOp.builder("concat") - .addInputs(empty, empty, empty) + .addInputs(empty01, empty01, empty01) .addIntegerArguments(0) //axis = 0 .build(); List l = op.calculateOutputShape(); assertEquals(1, l.size()); assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{0, 1}, l.get(0).getShape()); - op.addOutputArgument(empty); + op.addOutputArgument(Nd4j.create(DataType.FLOAT, 0, 1)); + Nd4j.exec(op); + + + op = DynamicCustomOp.builder("concat") + .addInputs(empty01, empty01, one1, empty01) + .addIntegerArguments(0) //axis = 0 + .build(); + l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertFalse(l.get(0).isEmpty()); + assertArrayEquals(new long[]{1, 1}, l.get(0).getShape()); + op.addOutputArgument(Nd4j.create(DataType.FLOAT, 1, 1)); Nd4j.exec(op); } @@ -2104,21 +2113,21 @@ public class ShapeOpValidation extends BaseOpValidation { public void testEmptyGather(){ /* tf.reset_default_graph() - # Hack to create empty array - input = tf.constant([False], dtype=tf.bool) - empty = tf.where(condition=input) - emptyFloat = tf.cast(empty, tf.float32) - emptyInt = tf.cast(empty, tf.int32) + inputFloat = tf.constant([], shape=[0,2,3], dtype=tf.float32) + emptyInt = tf.constant([], shape=[0], dtype=tf.int32) - gather = tf.gather(params=emptyFloat, indices=emptyInt) + gather = tf.gather(params=inputFloat, indices=emptyInt) sess = tf.Session() out = sess.run([gather]) print(out[0].shape) print(out[0]); + + > (0, 2, 3) + > [] */ - INDArray emptyFloat = Nd4j.empty(DataType.FLOAT); - INDArray emptyInt = Nd4j.empty(DataType.INT); + INDArray emptyFloat = Nd4j.create(DataType.FLOAT, 0, 2, 3); + INDArray emptyInt = Nd4j.create(DataType.INT, 0); DynamicCustomOp op = DynamicCustomOp.builder("gather") .addInputs(emptyFloat, emptyInt) .build(); @@ -2126,6 +2135,7 @@ public class ShapeOpValidation extends BaseOpValidation { List l = op.calculateOutputShape(); assertEquals(1, l.size()); assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{0,2,3}, l.get(0).getShape()); INDArray out = Nd4j.empty(DataType.FLOAT); op.addOutputArgument(out); @@ -2259,7 +2269,7 @@ public class ShapeOpValidation extends BaseOpValidation { List l = op.calculateOutputShape(); assertTrue(l.get(0).isEmpty()); - INDArray out = Nd4j.empty(DataType.INT); + INDArray out = Nd4j.create(DataType.INT, 0); op.setOutputArgument(0, out); Nd4j.exec(op); @@ -2278,7 +2288,7 @@ public class ShapeOpValidation extends BaseOpValidation { List l = op.calculateOutputShape(); assertTrue(l.get(0).isEmpty()); - INDArray out = Nd4j.empty(DataType.INT); + INDArray out = Nd4j.create(DataType.INT, 0); op.setOutputArgument(0, out); Nd4j.exec(op); @@ -2295,11 +2305,11 @@ public class ShapeOpValidation extends BaseOpValidation { .build(); List l = op.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); assertEquals(1, l.size()); + assertArrayEquals(new long[]{0,4}, l.get(0).getShape()); assertTrue(l.get(0).isEmpty()); - op.setOutputArgument(0, Nd4j.empty(DataType.FLOAT)); + op.setOutputArgument(0, Nd4j.create(DataType.FLOAT, 0, 4)); Nd4j.exec(op); } @@ -2312,11 +2322,116 @@ public class ShapeOpValidation extends BaseOpValidation { DynamicCustomOp op = new Fill(shape, value, null); List l = op.calculateOutputShape(); - System.out.println(Arrays.toString(l.get(0).getShape())); assertEquals(1, l.size()); assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{0,4}, l.get(0).getShape()); - op.setOutputArgument(0, Nd4j.empty(DataType.FLOAT)); + op.setOutputArgument(0, Nd4j.create(DataType.FLOAT, 0, 4)); Nd4j.exec(op); } + + @Test + public void testPermuteShapeDynamicAxis(){ + + DynamicCustomOp op = DynamicCustomOp.builder("permute") + .addInputs(Nd4j.rand(DataType.FLOAT, 3, 4), + Nd4j.createFromArray(1, 0)) + .build(); + List l = op.calculateOutputShape(); + System.out.println(Arrays.toString(l.get(0).getShape())); + assertArrayEquals(new long[]{4, 3}, l.get(0).getShape()); + + op = DynamicCustomOp.builder("permute") + .addInputs(Nd4j.rand(DataType.FLOAT, 3, 4)) + .addIntegerArguments(1, 0) + .build(); + l = op.calculateOutputShape(); + System.out.println(Arrays.toString(l.get(0).getShape())); + assertArrayEquals(new long[]{4, 3}, l.get(0).getShape()); + + + op = DynamicCustomOp.builder("permute") + .addInputs(Nd4j.rand(DataType.FLOAT, 3, 4, 5), + Nd4j.createFromArray(1, 2, 0)) + .build(); + l = op.calculateOutputShape(); + System.out.println(Arrays.toString(l.get(0).getShape())); + assertArrayEquals(new long[]{4, 5, 3}, l.get(0).getShape()); + } + + @Test + public void testGather2(){ + SameDiff sd = SameDiff.create(); + SDVariable input = sd.var("in", Nd4j.arange(6).castTo(DataType.FLOAT).reshape(2,3)); + SDVariable indices = sd.constant("indices", Nd4j.createFromArray(0)); + + SDVariable gathered = sd.gather(input, indices, 1); + SDVariable loss = gathered.std(true); + + sd.exec(null, gathered.getVarName()); + sd.setLossVariables(gathered.getVarName()); + + String err = OpValidation.validate(new TestCase(sd) + .gradCheckEpsilon(1e-3) + .gradCheckMaxRelativeError(1e-4)); + + assertNull(err); + } + + @Test + public void testPermute3(){ + INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); + INDArray permute = Nd4j.createFromArray(1,0); + + System.out.println(in); + + SameDiff sd = SameDiff.create(); + SDVariable v = sd.var(in); + SDVariable v2 = sd.constant(permute); + + SDVariable out = v.permute(v2); + + INDArray exp = in.transpose(); + INDArray outArr = out.eval(); + assertEquals(exp, outArr); + } + + @Test + public void testPermute4(){ + Nd4j.getExecutioner().enableDebugMode(true); + Nd4j.getExecutioner().enableVerboseMode(true); + INDArray in = Nd4j.linspace(DataType.FLOAT, 1, 6, 1).reshape(3,2); + INDArray permute = Nd4j.createFromArray(1,0); + + INDArray exp = in.transpose(); + + for( boolean iargs : new boolean[]{true, false}) { + + + DynamicCustomOp.DynamicCustomOpsBuilder b = DynamicCustomOp.builder("permute") + .addInputs(in) + .addOutputs(Nd4j.create(DataType.FLOAT, 2, 3)); + + if(iargs){ + b.addIntegerArguments(1, 0); + } else { + b.addInputs(permute); + } + + DynamicCustomOp op = b.build(); + Nd4j.exec(op); + + System.out.println(in); + System.out.println(op.outputArguments()[0]); + + assertEquals(exp, op.getOutputArgument(0)); + } + } + + @Test + public void testInvertPermutation(){ + DynamicCustomOp op = DynamicCustomOp.builder("invert_permutation") + .addInputs(Nd4j.createFromArray(1, 0)) + .build(); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index d1bff29fe..e4f3bd9ac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -1916,4 +1916,19 @@ public class TransformOpValidation extends BaseOpValidation { assertEquals(outCC, outFC); assertEquals(outCC, outFF); } + + @Test + public void testLogSumExp(){ + Nd4j.getRandom().setSeed(12345); + INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var(inputArr); + SDVariable lse = sd.math().logSumExp(in); + INDArray out = lse.eval(); + + INDArray exp = Transforms.exp(inputArr, true); + INDArray sum = exp.sum(); + INDArray log = Transforms.log(sum); + assertEquals(log, out); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java index 97cf4bb7e..1b94afab4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FlatBufferSerdeTest.java @@ -114,7 +114,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { assertEquals("in", fg.placeholders(0)); //Check loss variables: - //assertEquals(sd.getLossVariables(), fg) + assertEquals(sd.getLossVariables().size(), fg.lossVariablesLength()); } @Test @@ -209,12 +209,38 @@ public class FlatBufferSerdeTest extends BaseNd4jTest { //Check placeholders Map vBefore = sd.variableMap(); - Map vAfter = sd.variableMap(); + Map vAfter = restored.variableMap(); assertEquals(vBefore.keySet(), vAfter.keySet()); for(String s : vBefore.keySet()){ assertEquals(s, vBefore.get(s).isPlaceHolder(), vAfter.get(s).isPlaceHolder()); assertEquals(s, vBefore.get(s).isConstant(), vAfter.get(s).isConstant()); } + + + //Check save methods + for(boolean withUpdaterState : new boolean[]{false, true}) { + + File f2 = testDir.newFile(); + sd.save(f2, withUpdaterState); + SameDiff r2 = SameDiff.load(f2, withUpdaterState); + assertEquals(varsOrig.size(), r2.variables().size()); + assertEquals(fOrig.length, r2.functions().length); + assertEquals(sd.getLossVariables(), r2.getLossVariables()); + + //Save via stream: + File f3 = testDir.newFile(); + try(OutputStream os = new BufferedOutputStream(new FileOutputStream(f3))){ + sd.save(os, withUpdaterState); + } + + //Load via stream: + try(InputStream is = new BufferedInputStream(new FileInputStream(f3))) { + SameDiff r3 = SameDiff.load(is, withUpdaterState); + assertEquals(varsOrig.size(), r3.variables().size()); + assertEquals(fOrig.length, r3.functions().length); + assertEquals(sd.getLossVariables(), r3.getLossVariables()); + } + } } } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index b23edfc59..d5b99dbe7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -2742,6 +2742,48 @@ public class SameDiffTests extends BaseNd4jTest { sd.fit(new SingletonMultiDataSetIterator(new DataSet(inArr, null).toMultiDataSet()), 1); } + @Test + public void testPlaceholderToConstant() { + Nd4j.getRandom().setSeed(12345); + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3); + SDVariable in2 = sd.placeHolder("in2", DataType.FLOAT, 3, 4); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 1, 4)); + SDVariable mmul = in.mmul(in2); + SDVariable add = mmul.add(b); + SDVariable tanh = sd.math().tanh(add); + SDVariable loss = sd.variance(tanh, true); + + INDArray inArr = Nd4j.rand(DataType.FLOAT, 1, 3); + in.setArray(inArr); + INDArray inArr2 = Nd4j.rand(DataType.FLOAT, 3,4); + + TrainingConfig c = TrainingConfig.builder() + .updater(new Adam(0.1)) + .weightDecay(0.01, true) + .dataSetFeatureMapping("in", "in2") + .skipBuilderValidation(true) + .build(); + sd.setTrainingConfig(c); + + + sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr, inArr2}, null)), 1); + + INDArray out = tanh.eval(); + + in.convertToConstant(); + + INDArray out2 = tanh.eval(); + + assertEquals(out, out2); + assertEquals(VariableType.CONSTANT, in.getVariableType()); + assertEquals(inArr, in.getArr()); + + //Sanity check on fitting: + sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(new INDArray[]{inArr2}, null)), 1); + } + @Test public void testConvertToVariable() { Nd4j.getRandom().setSeed(12345); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java index 83f52acf4..2411af627 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/ui/UIListenerTest.java @@ -14,13 +14,16 @@ import org.nd4j.graph.UIGraphStructure; import org.nd4j.graph.UIStaticInfoRecord; import org.nd4j.graph.ui.LogFileWriter; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.IrisDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.primitives.Pair; import java.io.File; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.*; @@ -54,7 +57,15 @@ public class UIListenerTest { .weightDecay(1e-3, true) .build()); - sd.fit(iter, 30); + sd.fit(iter, 20); + + //Test inference after training with UI Listener still around + Map m = new HashMap<>(); + iter.reset(); + m.put("in", iter.next().getFeatures()); + INDArray out = sd.execSingle(m, "softmax"); + assertNotNull(out); + assertArrayEquals(new long[]{150, 3}, out.shape()); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java new file mode 100644 index 000000000..4d39fcbae --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EmptyEvaluationTests.java @@ -0,0 +1,138 @@ +package org.nd4j.evaluation; + +import org.junit.Test; +import org.nd4j.evaluation.classification.*; +import org.nd4j.evaluation.regression.RegressionEvaluation; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.factory.Nd4jBackend; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class EmptyEvaluationTests extends BaseNd4jTest { + + public EmptyEvaluationTests(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Test + public void testEmptyEvaluation() { + Evaluation e = new Evaluation(); + System.out.println(e.stats()); + + for (Evaluation.Metric m : Evaluation.Metric.values()) { + try { + e.scoreForMetric(m); + fail("Expected exception"); + } catch (Throwable t){ + assertTrue(t.getMessage(), t.getMessage().contains("no evaluation has been performed")); + } + } + } + + @Test + public void testEmptyRegressionEvaluation() { + RegressionEvaluation re = new RegressionEvaluation(); + re.stats(); + + for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + try { + re.scoreForMetric(m); + } catch (Throwable t){ + assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + } + } + } + + @Test + public void testEmptyEvaluationBinary() { + EvaluationBinary eb = new EvaluationBinary(); + eb.stats(); + + for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) { + try { + eb.scoreForMetric(m, 0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + } + } + } + + @Test + public void testEmptyROC() { + ROC roc = new ROC(); + roc.stats(); + + for (ROC.Metric m : ROC.Metric.values()) { + try { + roc.scoreForMetric(m); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("no evaluation")); + } + } + } + + @Test + public void testEmptyROCBinary() { + ROCBinary rb = new ROCBinary(); + rb.stats(); + + for (ROCBinary.Metric m : ROCBinary.Metric.values()) { + try { + rb.scoreForMetric(m, 0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("eval must be called")); + } + } + } + + @Test + public void testEmptyROCMultiClass() { + ROCMultiClass r = new ROCMultiClass(); + r.stats(); + + for (ROCMultiClass.Metric m : ROCMultiClass.Metric.values()) { + try { + r.scoreForMetric(m, 0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("no data")); + } + } + + } + + @Test + public void testEmptyEvaluationCalibration() { + EvaluationCalibration ec = new EvaluationCalibration(); + ec.stats(); + + try { + ec.getResidualPlot(0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("no data")); + } + try { + ec.getProbabilityHistogram(0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("no data")); + } + try { + ec.getReliabilityDiagram(0); + fail("Expected exception"); + } catch (Throwable t) { + assertTrue(t.getMessage(), t.getMessage().contains("no data")); + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java index 370499e2b..c864f6004 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationBinaryTest.java @@ -21,12 +21,17 @@ import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; +import java.util.ArrayList; +import java.util.List; + import static org.junit.Assert.assertEquals; import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*; /** @@ -237,4 +242,200 @@ public class EvaluationBinaryTest extends BaseNd4jTest { System.out.println(eb.stats()); } + + + @Test + public void testEvaluationBinary3d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationBinary e3d = new EvaluationBinary(); + EvaluationBinary e2d = new EvaluationBinary(); + + e3d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) { + for( int i=0; i<5; i++ ) { + double d1 = e3d.scoreForMetric(m, i); + double d2 = e2d.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testEvaluationBinary4d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationBinary e4d = new EvaluationBinary(); + EvaluationBinary e2d = new EvaluationBinary(); + + e4d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) { + for( int i=0; i<3; i++ ) { + double d1 = e4d.scoreForMetric(m, i); + double d2 = e2d.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testEvaluationBinary3dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape + INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10); + rowsP.clear(); + rowsL.clear(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask2d.getDouble(idx[0], idx[1]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationBinary e3d_m2d = new EvaluationBinary(); + EvaluationBinary e2d_m2d = new EvaluationBinary(); + e3d_m2d.eval(label, prediction, mask2d); + e2d_m2d.eval(l2d, p2d); + + + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + EvaluationBinary e4d_m2 = new EvaluationBinary(); + EvaluationBinary e2d_m2 = new EvaluationBinary(); + e4d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){ + for(int i=0; i<3; i++ ) { + double d1 = e4d_m2.scoreForMetric(m, i); + double d2 = e2d_m2.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testEvaluationBinary4dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check per-example masking: + INDArray mask1dPerEx = Nd4j.createFromArray(1, 0); + + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask1dPerEx.getDouble(idx[0]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationBinary e4d_m1 = new EvaluationBinary(); + EvaluationBinary e2d_m1 = new EvaluationBinary(); + e4d_m1.eval(label, prediction, mask1dPerEx); + e2d_m1.eval(l2d, p2d); + for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){ + for( int i=0; i<3; i++ ) { + double d1 = e4d_m1.scoreForMetric(m, i); + double d2 = e2d_m1.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + EvaluationBinary e3d_m2 = new EvaluationBinary(); + EvaluationBinary e2d_m2 = new EvaluationBinary(); + e3d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){ + for( int i=0; i<3; i++) { + double d1 = e3d_m2.scoreForMetric(m, i); + double d2 = e2d_m2.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java index 2c4de9442..219ccc19c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/EvaluationCalibrationTest.java @@ -21,13 +21,17 @@ import org.nd4j.evaluation.classification.EvaluationBinary; import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.ops.transforms.Transforms; +import java.util.ArrayList; +import java.util.List; import java.util.Random; import static org.junit.Assert.assertEquals; @@ -362,4 +366,66 @@ public class EvaluationCalibrationTest extends BaseNd4jTest { } } } + + @Test + public void testEvaluationCalibration3d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationCalibration e3d = new EvaluationCalibration(); + EvaluationCalibration e2d = new EvaluationCalibration(); + + e3d.eval(label, prediction); + e2d.eval(l2d, p2d); + + System.out.println(e2d.stats()); + + assertEquals(e2d, e3d); + + assertEquals(e2d.stats(), e3d.stats()); + } + + @Test + public void testEvaluationCalibration3dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape + INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask2d.getDouble(idx[0], idx[1]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + EvaluationCalibration e3d_m2d = new EvaluationCalibration(); + EvaluationCalibration e2d_m2d = new EvaluationCalibration(); + e3d_m2d.eval(label, prediction, mask2d); + e2d_m2d.eval(l2d, p2d); + + assertEquals(e3d_m2d, e2d_m2d); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java index 717cd1396..e6b48957b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/ROCBinaryTest.java @@ -24,10 +24,16 @@ import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; + +import java.util.ArrayList; +import java.util.List; import static org.junit.Assert.assertEquals; @@ -207,4 +213,201 @@ public class ROCBinaryTest extends BaseNd4jTest { } } } + + + + @Test + public void testROCBinary3d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + ROCBinary e3d = new ROCBinary(); + ROCBinary e2d = new ROCBinary(); + + e3d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (ROCBinary.Metric m : ROCBinary.Metric.values()) { + for( int i=0; i<5; i++ ) { + double d1 = e3d.scoreForMetric(m, i); + double d2 = e2d.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testROCBinary4d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + ROCBinary e4d = new ROCBinary(); + ROCBinary e2d = new ROCBinary(); + + e4d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (ROCBinary.Metric m : ROCBinary.Metric.values()) { + for( int i=0; i<3; i++ ) { + double d1 = e4d.scoreForMetric(m, i); + double d2 = e2d.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testROCBinary3dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape + INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10); + rowsP.clear(); + rowsL.clear(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask2d.getDouble(idx[0], idx[1]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + ROCBinary e3d_m2d = new ROCBinary(); + ROCBinary e2d_m2d = new ROCBinary(); + e3d_m2d.eval(label, prediction, mask2d); + e2d_m2d.eval(l2d, p2d); + + + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + ROCBinary e4d_m2 = new ROCBinary(); + ROCBinary e2d_m2 = new ROCBinary(); + e4d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(ROCBinary.Metric m : ROCBinary.Metric.values()){ + for(int i=0; i<3; i++ ) { + double d1 = e4d_m2.scoreForMetric(m, i); + double d2 = e2d_m2.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } + + @Test + public void testROCBinary4dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check per-example masking: + INDArray mask1dPerEx = Nd4j.createFromArray(1, 0); + + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask1dPerEx.getDouble(idx[0]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + ROCBinary e4d_m1 = new ROCBinary(); + ROCBinary e2d_m1 = new ROCBinary(); + e4d_m1.eval(label, prediction, mask1dPerEx); + e2d_m1.eval(l2d, p2d); + for(ROCBinary.Metric m : ROCBinary.Metric.values()){ + for( int i=0; i<3; i++ ) { + double d1 = e4d_m1.scoreForMetric(m, i); + double d2 = e2d_m1.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + ROCBinary e3d_m2 = new ROCBinary(); + ROCBinary e2d_m2 = new ROCBinary(); + e3d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(ROCBinary.Metric m : ROCBinary.Metric.values()){ + for( int i=0; i<3; i++) { + double d1 = e3d_m2.scoreForMetric(m, i); + double d2 = e2d_m2.scoreForMetric(m, i); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java index 4d01d67ea..92acb1a20 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/evaluation/RegressionEvalTest.java @@ -21,9 +21,12 @@ import org.nd4j.evaluation.classification.EvaluationCalibration; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.indexing.INDArrayIndex; +import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.ArrayList; import java.util.List; @@ -46,7 +49,7 @@ public class RegressionEvalTest extends BaseNd4jTest { return 'c'; } - @Test(expected = IllegalArgumentException.class) + @Test(expected = IllegalStateException.class) public void testEvalParameters() { int specCols = 5; INDArray labels = Nd4j.ones(3); @@ -227,4 +230,189 @@ public class RegressionEvalTest extends BaseNd4jTest { assertEquals(e1, e2); } + + @Test + public void testRegressionEval3d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + RegressionEvaluation e3d = new RegressionEvaluation(); + RegressionEvaluation e2d = new RegressionEvaluation(); + + e3d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + double d1 = e3d.scoreForMetric(m); + double d2 = e2d.scoreForMetric(m); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + + @Test + public void testRegressionEval4d() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + RegressionEvaluation e4d = new RegressionEvaluation(); + RegressionEvaluation e2d = new RegressionEvaluation(); + + e4d.eval(label, prediction); + e2d.eval(l2d, p2d); + + for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) { + double d1 = e4d.scoreForMetric(m); + double d2 = e2d.scoreForMetric(m); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + + @Test + public void testRegressionEval3dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape + INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10); + rowsP.clear(); + rowsL.clear(); + NdIndexIterator iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask2d.getDouble(idx[0], idx[1]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + RegressionEvaluation e3d_m2d = new RegressionEvaluation(); + RegressionEvaluation e2d_m2d = new RegressionEvaluation(); + e3d_m2d.eval(label, prediction, mask2d); + e2d_m2d.eval(l2d, p2d); + + + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + RegressionEvaluation e4d_m2 = new RegressionEvaluation(); + RegressionEvaluation e2d_m2 = new RegressionEvaluation(); + e4d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + double d1 = e4d_m2.scoreForMetric(m); + double d2 = e2d_m2.scoreForMetric(m); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } + + @Test + public void testRegressionEval4dMasking() { + INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10); + + List rowsP = new ArrayList<>(); + List rowsL = new ArrayList<>(); + + //Check per-example masking: + INDArray mask1dPerEx = Nd4j.createFromArray(1, 0); + + NdIndexIterator iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + if(mask1dPerEx.getDouble(idx[0]) != 0.0) { + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + } + } + + INDArray p2d = Nd4j.vstack(rowsP); + INDArray l2d = Nd4j.vstack(rowsL); + + RegressionEvaluation e4d_m1 = new RegressionEvaluation(); + RegressionEvaluation e2d_m1 = new RegressionEvaluation(); + e4d_m1.eval(label, prediction, mask1dPerEx); + e2d_m1.eval(l2d, p2d); + for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + double d1 = e4d_m1.scoreForMetric(m); + double d2 = e2d_m1.scoreForMetric(m); + assertEquals(m.toString(), d2, d1, 1e-6); + } + + //Check per-output masking: + INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape()); + rowsP.clear(); + rowsL.clear(); + List rowsM = new ArrayList<>(); + iter = new NdIndexIterator(2, 10, 10); + while (iter.hasNext()) { + long[] idx = iter.next(); + INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])}; + rowsP.add(prediction.get(idxs)); + rowsL.add(label.get(idxs)); + rowsM.add(perOutMask.get(idxs)); + } + p2d = Nd4j.vstack(rowsP); + l2d = Nd4j.vstack(rowsL); + INDArray m2d = Nd4j.vstack(rowsM); + + RegressionEvaluation e4d_m2 = new RegressionEvaluation(); + RegressionEvaluation e2d_m2 = new RegressionEvaluation(); + e4d_m2.eval(label, prediction, perOutMask); + e2d_m2.eval(l2d, p2d, m2d); + for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){ + double d1 = e4d_m2.scoreForMetric(m); + double d2 = e2d_m2.scoreForMetric(m); + assertEquals(m.toString(), d2, d1, 1e-6); + } + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 4aa067de9..f012fb0f7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -29,6 +29,7 @@ import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.validation.OpValidation; @@ -36,6 +37,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -171,7 +173,7 @@ public class TFGraphTestAllHelper { if(maxRelErrorOverride == null) { long[] sTf = tfPred.shape(); long[] sNd4j = nd4jPred.shape(); - assertArrayEquals("Shapes for node \"" + outputNode + "\" are not equal: " + Arrays.toString(sTf) + " vs " + Arrays.toString(sNd4j), sTf, sNd4j); + assertArrayEquals("Shapes for node \"" + outputNode + "\" are not equal: TF: " + Arrays.toString(sTf) + " vs SD: " + Arrays.toString(sNd4j), sTf, sNd4j); // TODO: once we add more dtypes files - this should be removed if (tfPred.dataType() != nd4jPred.dataType()) @@ -179,14 +181,37 @@ public class TFGraphTestAllHelper { boolean eq = tfPred.equals(nd4jPred); if(!eq){ - NDArrayStrings s = new NDArrayStrings(); - String s1 = s.format(tfPred, false); - String s2 = s.format(nd4jPred, false); - System.out.print("TF: "); - System.out.println(s1); - System.out.print("SD: "); - System.out.println(s2); + //Check for both NaN, both inf + if(tfPred.dataType().isFPType() && tfPred.equalShapes(nd4jPred) && tfPred.isNaN().castTo(DataType.INT).sumNumber().intValue() == tfPred.length() + && nd4jPred.isNaN().castTo(DataType.INT).sumNumber().intValue() == nd4jPred.length()){ + //All NaNs in both arrays + eq = true; + } else if(tfPred.dataType().isFPType() && tfPred.equalShapes(nd4jPred) && tfPred.isInfinite().castTo(DataType.INT).sumNumber().intValue() == tfPred.length() + && nd4jPred.isInfinite().castTo(DataType.INT).sumNumber().intValue() == nd4jPred.length()){ + //All infinite in both arrays. But need to check that it's all positive vs. negative infinite in both cases... + NdIndexIterator iter = new NdIndexIterator(tfPred.shape()); + eq = true; + while(iter.hasNext()){ + long[] next = iter.next(); + //Already know they are both infinite, only question is whether they are both positive and negative + double d1 = tfPred.getDouble(next); + double d2 = nd4jPred.getDouble(next); + if((d1 > 0) != (d2 > 0)){ + eq = false; + break; + } + } + } + if(!eq) { + NDArrayStrings s = new NDArrayStrings(); + String s1 = s.format(tfPred, false); + String s2 = s.format(nd4jPred, false); + System.out.print("TF: "); + System.out.println(s1); + System.out.print("SD: "); + System.out.println(s2); + } } assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq); } else { @@ -335,10 +360,19 @@ public class TFGraphTestAllHelper { // = TFGraphMapper.getInstance().importGraph(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getInputStream()); //System.out.println(graph.summary()); if (executeWith.equals(ExecuteWith.SAMEDIFF)) { + List outputs = graph.outputs(); + if(outputs.isEmpty()){ + //Edge case: no ops + List vars = graph.variables(); + outputs = new ArrayList<>(); + for(SDVariable v : vars) { + outputs.add(v.getVarName()); + } + } if (!inputs.isEmpty()) { - graph.exec(inputs, graph.outputs()); //This is expected to be just one result + graph.exec(inputs, outputs); //This is expected to be just one result } else { - graph.exec(Collections.emptyMap(), graph.outputs()); //there are graphs with no placeholders like g_00 + graph.exec(Collections.emptyMap(), outputs); //there are graphs with no placeholders like g_00 } } else if (executeWith.equals(ExecuteWith.LIBND4J)) { for (String input : inputs.keySet()) { @@ -609,8 +643,13 @@ public class TFGraphTestAllHelper { } if (content.isEmpty()) { - if (varShape.length == 1 && varShape[0] == 0) { - varValue = Nd4j.empty(type); + //Should be zeros in shape + boolean foundZero = false; + for( int s : varShape){ + foundZero |= (s == 0); + } + if(foundZero){ + varValue = Nd4j.create(type, ArrayUtil.toLongArray(varShape)); } else { throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond()); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listener/ImportDebugListener.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listener/ImportDebugListener.java index 0eb06da30..7722f1a2b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listener/ImportDebugListener.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/listener/ImportDebugListener.java @@ -35,7 +35,7 @@ public class ImportDebugListener extends BaseListener { } @Override - public void opExecution(SameDiff sd, At at, SameDiffOp op, INDArray[] outputs) { + public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) { //No op for( int i=0; i shapes = Arrays.asList(new long[]{3, 4}, new long[]{3, 1}, new long[]{1,3}); + for(long[] shape : shapes){ + long length = ArrayUtil.prodLong(shape); + INDArray orig = Nd4j.arange(length).castTo(DataType.DOUBLE).reshape(shape); + for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, + DataType.LONG, DataType.SHORT, DataType.UBYTE, DataType.UINT16, DataType.UINT32, DataType.UINT64, DataType.BFLOAT16}) { + INDArray arr = orig.castTo(dt); + + float[][] fArr = arr.toFloatMatrix(); + double[][] dArr = arr.toDoubleMatrix(); + int[][] iArr = arr.toIntMatrix(); + long[][] lArr = arr.toLongMatrix(); + + INDArray f = Nd4j.createFromArray(fArr).castTo(dt); + INDArray d = Nd4j.createFromArray(dArr).castTo(dt); + INDArray i = Nd4j.createFromArray(iArr).castTo(dt); + INDArray l = Nd4j.createFromArray(lArr).castTo(dt); + + assertEquals(arr, f); + assertEquals(arr, d); + assertEquals(arr, i); + assertEquals(arr, l); + } + } + } + + @Test + public void testToXVector(){ + + List shapes = Arrays.asList(new long[]{3}, new long[]{3, 1}, new long[]{1,3}); + for(long[] shape : shapes){ + INDArray orig = Nd4j.arange(3).castTo(DataType.DOUBLE).reshape(shape); + for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT, + DataType.LONG, DataType.SHORT, DataType.UBYTE, DataType.UINT16, DataType.UINT32, DataType.UINT64, DataType.BFLOAT16}) { + INDArray arr = orig.castTo(dt); + + float[] fArr = arr.toFloatVector(); + double[] dArr = arr.toDoubleVector(); + int[] iArr = arr.toIntVector(); + long[] lArr = arr.toLongVector(); + + INDArray f = Nd4j.createFromArray(fArr).castTo(dt).reshape(shape); + INDArray d = Nd4j.createFromArray(dArr).castTo(dt).reshape(shape); + INDArray i = Nd4j.createFromArray(iArr).castTo(dt).reshape(shape); + INDArray l = Nd4j.createFromArray(lArr).castTo(dt).reshape(shape); + + assertEquals(arr, f); + assertEquals(arr, d); + assertEquals(arr, i); + assertEquals(arr, l); + } + } } @@ -7761,6 +7826,25 @@ public class Nd4jTestsC extends BaseNd4jTest { assertArrayEquals(new long[]{3}, sum2.shape()); } + @Test + public void testMedianEdgeCase(){ + INDArray rowVec = Nd4j.rand(DataType.FLOAT, 1, 10); + INDArray median = rowVec.median(0); + assertEquals(rowVec.reshape(10), median); + + INDArray colVec = Nd4j.rand(DataType.FLOAT, 10, 1); + median = colVec.median(1); + assertEquals(colVec.reshape(10), median); + + //Non-edge cases: + rowVec.median(1); + colVec.median(0); + + //full array case: + rowVec.median(); + colVec.median(); + } + /////////////////////////////////////////////////////// protected static void fillJvmArray3D(float[][][] arr) { int cnt = 1; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java index 0c07f9a7f..ab724ee1b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTests.java @@ -162,7 +162,7 @@ public class IndexingTests extends BaseNd4jTest { @Test public void testSlicing() { - INDArray arange = Nd4j.arange(1, 17).reshape(4, 4); + INDArray arange = Nd4j.arange(1, 17).reshape(4, 4).castTo(DataType.DOUBLE); INDArray slice1Assert = Nd4j.create(new double[] {2, 6, 10, 14}); INDArray slice1Test = arange.slice(1); assertEquals(slice1Assert, slice1Test); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java index d80d46d48..ab05d67d2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/indexing/IndexingTestsC.java @@ -86,7 +86,7 @@ public class IndexingTestsC extends BaseNd4jTest { @Test public void testIntervalsIn3D() { - INDArray arr = Nd4j.arange(8).reshape(2, 2, 2); + INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray rest = arr.get(interval(1, 2), interval(0, 2), interval(0, 2)); assertEquals(assertion, rest); @@ -95,7 +95,7 @@ public class IndexingTestsC extends BaseNd4jTest { @Test public void testSmallInterval() { - INDArray arr = Nd4j.arange(8).reshape(2, 2, 2); + INDArray arr = Nd4j.arange(8).reshape(2, 2, 2).castTo(DataType.DOUBLE); INDArray assertion = Nd4j.create(new double[][] {{4, 5}, {6, 7}}).reshape(1, 2, 2); INDArray rest = arr.get(interval(1, 2), all(), all()); assertEquals(assertion, rest); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java index 4f492392b..5b4cf2866 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/broadcast/BasicBroadcastTests.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RealDivOp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -154,6 +155,35 @@ public class BasicBroadcastTests extends BaseNd4jTest { assertEquals(e, z); } + @Test + public void emptyBroadcastTest_1() { + val x = Nd4j.create(DataType.FLOAT, 1, 2); + val y = Nd4j.create(DataType.FLOAT, 0, 2); + + val z = x.add(y); + assertEquals(y, z); + } + + @Test(expected = IllegalArgumentException.class) + public void emptyBroadcastTest_2() { + val x = Nd4j.create(DataType.FLOAT, 1, 2); + val y = Nd4j.create(DataType.FLOAT, 0, 2); + + val z = x.addi(y); + assertEquals(y, z); + } + + @Test + public void emptyBroadcastTest_3() { + val x = Nd4j.create(DataType.FLOAT, 1, 0, 1); + val y = Nd4j.create(DataType.FLOAT, 1, 0, 2); + + val op = new RealDivOp(new INDArray[]{x, y}, new INDArray[]{}); + val z = Nd4j.exec(op)[0]; + + assertEquals(y, z); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java index 42f69dc8e..caec61321 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/serde/NumpyFormatTests.java @@ -211,6 +211,51 @@ public class NumpyFormatTests extends BaseNd4jTest { assertEquals(exp, arr); } + + @Test + public void testNpy() throws Exception { + for(boolean empty : new boolean[]{false, true}) { + val dir = testDir.newFolder(); + if(!empty) { + new ClassPathResource("numpy_arrays/npy/3,4/").copyDirectory(dir); + } else { + new ClassPathResource("numpy_arrays/npy/0,3_empty/").copyDirectory(dir); + } + + File[] files = dir.listFiles(); + int cnt = 0; + + for (File f : files) { + if (!f.getPath().endsWith(".npy")) { + log.warn("Skipping: {}", f); + continue; + } + + String path = f.getAbsolutePath(); + int lastDot = path.lastIndexOf('.'); + int lastUnderscore = path.lastIndexOf('_'); + String dtype = path.substring(lastUnderscore + 1, lastDot); + System.out.println(path + " : " + dtype); + + DataType dt = DataType.fromNumpy(dtype); + //System.out.println(dt); + + INDArray exp; + if(empty){ + exp = Nd4j.create(dt, 0, 3); + } else { + exp = Nd4j.arange(12).castTo(dt).reshape(3, 4); + } + INDArray act = Nd4j.createFromNpyFile(f); + + assertEquals("Failed with file [" + f.getName() + "]", exp, act); + cnt++; + } + + assertTrue(cnt > 0); + } + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index 111fd5a65..ce0b94b81 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -149,6 +149,146 @@ public class EmptyTests extends BaseNd4jTest { } } + @Test + public void testEmptyWithShape_1() { + val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); + + assertNotNull(array); + assertEquals(DataType.FLOAT, array.dataType()); + assertEquals(0, array.length()); + assertTrue(array.isEmpty()); + assertArrayEquals(new long[]{2, 0, 3}, array.shape()); + assertArrayEquals(new long[]{0, 0, 0}, array.stride()); + } + + @Test + public void testEmptyWithShape_2(){ + val array = Nd4j.create(DataType.FLOAT, 0); + + assertNotNull(array); + assertEquals(DataType.FLOAT, array.dataType()); + assertEquals(0, array.length()); + assertTrue(array.isEmpty()); + assertArrayEquals(new long[]{0}, array.shape()); + assertArrayEquals(new long[]{0}, array.stride()); + assertEquals(1, array.rank()); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyWithShape_3() { + val array = Nd4j.create(DataType.FLOAT, 2, 0, 3); + array.tensorAlongDimension(0, 2); + } + + @Test + public void testEmptyWithShape_4(){ + val array = Nd4j.create(DataType.FLOAT, 0, 3); + + assertNotNull(array); + assertEquals(DataType.FLOAT, array.dataType()); + assertEquals(0, array.length()); + assertTrue(array.isEmpty()); + assertArrayEquals(new long[]{0, 3}, array.shape()); + assertArrayEquals(new long[]{0, 0}, array.stride()); + assertEquals(2, array.rank()); + assertEquals(0, array.rows()); + assertEquals(3, array.columns()); + assertEquals(0, array.size(0)); + assertEquals(3, array.size(1)); + assertEquals(0, array.stride(0)); + assertEquals(0, array.stride(1)); + } + + @Test + public void testEmptyReduction_1() { + val x = Nd4j.create(DataType.FLOAT, 2, 0, 3); + val e = Nd4j.create(DataType.FLOAT, 2, 1, 3).assign(0); + + val reduced = x.sum(true, 1); + + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + } + + @Test + public void testEmptyReduction_2() { + val x = Nd4j.create(DataType.FLOAT, 2, 0, 3); + val e = Nd4j.create(DataType.FLOAT, 2, 3).assign(0); + + val reduced = x.sum(false, 1); + + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + } + + + @Test + public void testEmptyReduction_3() { + val x = Nd4j.create(DataType.FLOAT, 2, 0); + val e = Nd4j.create(DataType.FLOAT, 0); + + val reduced = x.argMax(0); + + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + } + + @Test(expected = IllegalArgumentException.class) + public void testEmptyReduction_4() { + val x = Nd4j.create(DataType.FLOAT, 2, 0); + val e = Nd4j.create(DataType.FLOAT, 0); + + val reduced = x.argMax(1); + + assertArrayEquals(e.shape(), reduced.shape()); + assertEquals(e, reduced); + } + + @Test + public void testEmptyCreateMethods(){ + DataType dt = DataType.FLOAT; + assertArrayEquals(new long[]{0}, Nd4j.create(0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.create(0,0).shape()); + assertArrayEquals(new long[]{0,0,0}, Nd4j.create(0,0,0).shape()); + assertArrayEquals(new long[]{0}, Nd4j.create(0L).shape()); + assertArrayEquals(new long[]{0}, Nd4j.create(dt, 0L).shape()); + + assertArrayEquals(new long[]{0}, Nd4j.zeros(0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.zeros(0,0).shape()); + assertArrayEquals(new long[]{0,0,0}, Nd4j.zeros(0,0,0).shape()); + assertArrayEquals(new long[]{0}, Nd4j.zeros(0L).shape()); + assertArrayEquals(new long[]{0}, Nd4j.zeros(dt, 0L).shape()); + + assertArrayEquals(new long[]{0}, Nd4j.ones(0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.ones(0,0).shape()); + assertArrayEquals(new long[]{0,0,0}, Nd4j.ones(0,0,0).shape()); + assertArrayEquals(new long[]{0}, Nd4j.ones(0L).shape()); + assertArrayEquals(new long[]{0}, Nd4j.ones(dt, 0L).shape()); + + assertArrayEquals(new long[]{0}, Nd4j.valueArrayOf(0, 1.0).shape()); + assertArrayEquals(new long[]{0}, Nd4j.valueArrayOf(0,1.0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.valueArrayOf(0,0,1.0).shape()); + assertArrayEquals(new long[]{1,0}, Nd4j.valueArrayOf(new long[]{1,0}, 1.0).shape()); + assertArrayEquals(new long[]{1,0}, Nd4j.valueArrayOf(new long[]{1,0}, 1.0f).shape()); + assertArrayEquals(new long[]{1,0}, Nd4j.valueArrayOf(new long[]{1,0}, 1L).shape()); + assertArrayEquals(new long[]{1,0}, Nd4j.valueArrayOf(new long[]{1,0}, 1).shape()); + + assertArrayEquals(new long[]{0}, Nd4j.createUninitialized(0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.createUninitialized(0,0).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.createUninitialized(dt, 0,0).shape()); + + assertArrayEquals(new long[]{0,0}, Nd4j.zerosLike(Nd4j.ones(0,0)).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.onesLike(Nd4j.ones(0,0)).shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.ones(0,0).like().shape()); + assertArrayEquals(new long[]{0,0}, Nd4j.ones(0,0).ulike().shape()); + } + + @Test + public void testEqualShapesEmpty(){ + assertTrue(Nd4j.create(0).equalShapes(Nd4j.create(0))); + assertFalse(Nd4j.create(0).equalShapes(Nd4j.create(1, 0))); + } + @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index fbfbd1aa5..23b982ca3 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -135,16 +135,19 @@ public abstract class BaseDataBuffer implements DataBuffer { * @param length the length of the view */ public BaseDataBuffer(Pointer pointer, Indexer indexer, long length) { - if (length < 1) - throw new IllegalArgumentException("Length must be >= 1"); + if (length < 0) + throw new IllegalArgumentException("Length must be >= 0"); + initTypeAndSize(); this.length = length; this.allocationMode = AllocationMode.MIXED_DATA_TYPES; this.underlyingLength = length; this.wrappedDataBuffer = this; - this.pointer = pointer; - setIndexer(indexer); + if (length > 0) { + this.pointer = pointer; + setIndexer(indexer); + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java index 1b2773f3a..15d6719be 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java @@ -65,6 +65,10 @@ public enum DataType { case INT: return 9; case LONG: return 10; case UBYTE: return 11; + case UINT16: return 12; + case UINT32: return 13; + case UINT64: return 14; + case BFLOAT16: return 17; case UTF8: return 50; default: throw new UnsupportedOperationException("Non-covered data type: [" + this + "]"); } diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java index 0c29fe839..bdeb882f9 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/linalg/util/ArrayUtil.java @@ -3484,4 +3484,34 @@ public class ArrayUtil { return target; } + + /** + * Is this shape an empty shape? + * Shape is considered to be an empty shape if it contains any zeros. + * Note: a length 0 shape is NOT considered empty (it's rank 0 scalar) + * @param shape Shape to check + * @return True if shape contains zeros + */ + public static boolean isEmptyShape(long[] shape){ + for( long l : shape){ + if(l == 0) + return true; + } + return false; + } + + /** + * Is this shape an empty shape? + * Shape is considered to be an empty shape if it contains any zeros. + * Note: a length 0 shape is NOT considered empty (it's rank 0 scalar) + * @param shape Shape to check + * @return True if shape contains zeros + */ + public static boolean isEmptyShape(int[] shape){ + for( int i : shape){ + if(i == 0) + return true; + } + return false; + } }