From 972fae60dc3afb39c08fcf05cb80ce2f170c1365 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 6 Dec 2019 11:10:44 +0300 Subject: [PATCH] Update master (#8511) * cleaned up bert iterator tests (#110) Signed-off-by: eraly * Various pre-release fixes (#111) * Various fixes Signed-off-by: AlexDBlack * Fix default dtypes for MaxPoolWithArgmax Signed-off-by: AlexDBlack * Small pre-release tweak (#112) * Log UI address on launch as in previous Play-based UI Signed-off-by: AlexDBlack * Logging level tweak for UI Signed-off-by: AlexDBlack * http not https Signed-off-by: AlexDBlack * datavec python ensure host (#113) * ensure host * one more host ensure * info->debug * [WIP] reverse improvements (#115) * initial commit Signed-off-by: raver119 * reverse draft Signed-off-by: raver119 * reverse kernel Signed-off-by: raver119 * reverse kernel Signed-off-by: raver119 * 2 micro fixes Signed-off-by: raver119 * Shugeo resize fix5 (#102) * Refactored resize images ops to use TF-like bool args as input. * Refactored helpers for cpu implementation of resize_bilinear and resize_nearest_neighbor ops. * Refactored cuda implementation for image.resize_bilinear and image.resize_nearest_neighbor ops helpers. * Refactored nearest_neighbor resize op. * Added a pair of tests for special case of resize_bilinear algorithm. * Fixed issue with resize_bilinear op. * Refactored cpu implementation for helpers with resize_nearest_neighbor op. * Final fixed for resize ops to conform TF v.1.5 * Refactored cuda helpers for resize_neares_neighbor op. * Fixed resize_bilinear to accept proper data. * Fixed issue with non-float input for resize_bilinear op. * Refactored cuda helper for resize_bilinear to proper process non-float inputs. * Added tests for resize_bilinear to int inputs. * Fixed ResizeBilinear wrapper * Tests fixed * Fixed float and bool constant to avoid overflow for some kind of compilers. * Corrected float constants with float data type. * Added f suffix for float constants. * Corrected float constant to avoid overflow with initializing lists. * Corrected float initializing list with float input. * Corrected bool constant with initalizing list. * Corrected float and bool values with initializing lists. * Fixed wrong constant. * Fixed issue with 1x1 input picture for resize. * ResizeBilinear default values on import fix Signed-off-by: raver119 --- .../java/org/datavec/python/NumpyArray.java | 8 +- .../org/datavec/python/PythonExecutioner.java | 2 +- .../iterator/TestBertIterator.java | 681 +++++++++--------- .../org/deeplearning4j/ui/VertxUIServer.java | 7 +- libnd4j/blas/NDArray.h | 47 +- libnd4j/blas/NDArray.hpp | 4 +- .../nn/pooling/maxpool_with_argmax.cpp | 2 +- .../generic/parity_ops/resize_bicubic.cpp | 4 +- .../generic/parity_ops/resize_linear.cpp | 39 +- .../generic/parity_ops/resize_neighbor.cpp | 33 +- .../declarable/helpers/cpu/image_resize.cpp | 213 +++--- .../declarable/helpers/cuda/image_resize.cu | 264 ++++--- .../ops/declarable/helpers/cuda/reverse.cu | 95 ++- .../ops/declarable/helpers/image_resize.h | 14 +- .../layers_tests/ConvolutionTests1.cpp | 18 +- .../layers_tests/ConvolutionTests2.cpp | 4 +- .../layers_tests/DeclarableOpsTests1.cpp | 9 +- .../layers_tests/DeclarableOpsTests10.cpp | 269 +++++-- .../layers_tests/DeclarableOpsTests12.cpp | 4 +- .../layers_tests/DeclarableOpsTests13.cpp | 194 ++--- .../layers_tests/DeclarableOpsTests15.cpp | 4 +- .../layers_tests/DeclarableOpsTests16.cpp | 41 ++ .../layers_tests/DeclarableOpsTests2.cpp | 22 +- .../layers_tests/DeclarableOpsTests3.cpp | 18 +- .../layers_tests/DeclarableOpsTests7.cpp | 2 +- .../layers_tests/DeclarableOpsTestsCuda1.cu | 18 +- .../layers_tests/JavaInteropTests.cpp | 8 +- .../layers_tests/NDArrayCudaBasicsTests.cu | 90 +-- .../tests_cpu/layers_tests/NDArrayTests.cpp | 86 ++- .../converters/ImportClassMapping.java | 1 + .../api/ops/impl/image/NonMaxSuppression.java | 2 +- .../api/ops/impl/image/ResizeBilinear.java | 20 +- .../layers/convolution/MaxPoolWithArgmax.java | 4 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 5 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 45 +- .../opvalidation/LayerOpValidation.java | 4 +- .../autodiff/samediff/ConvConfigTests.java | 16 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 3 - .../nd4j/linalg/custom/CustomOpsTests.java | 3 +- 39 files changed, 1420 insertions(+), 883 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index ab49cf5ea..24a2c2e09 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -21,6 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -60,6 +61,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } @@ -85,6 +87,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } } @@ -104,11 +107,12 @@ public class NumpyArray { nd4jStrides[i] = strides[i] / elemSize; } - this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); - + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } public NumpyArray(INDArray nd4jArray){ + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); address = buff.pointer().address(); shape = nd4jArray.shape(); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index c6272e7ad..0f926b9ad 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -605,7 +605,7 @@ public class PythonExecutioner { private static synchronized void _exec(String code) { - log.info(code); + log.debug(code); log.info("CPython: PyRun_SimpleStringFlag()"); int result = PyRun_SimpleStringFlags(code, null); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index a6716ba40..52644c360 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -17,11 +17,13 @@ package org.deeplearning4j.iterator; +import lombok.Getter; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.iterator.bert.BertMaskedLMMasker; +import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider; +import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider; import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory; import org.junit.Test; -import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -42,8 +44,12 @@ import static org.junit.Assert.*; public class TestBertIterator extends BaseDL4JTest { - private File pathToVocab = Resources.asFile("other/vocab.txt"); + private static File pathToVocab = Resources.asFile("other/vocab.txt"); private static Charset c = StandardCharsets.UTF_8; + private static String shortSentence = "I saw a girl with a telescope."; + private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + private static String sentenceA = "Goodnight noises everywhere"; + private static String sentenceB = "Goodnight moon"; public TestBertIterator() throws IOException { } @@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testBertSequenceClassification() throws Exception { - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); @@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest { System.out.println(mds.getFeatures(0)); System.out.println(mds.getFeaturesMaskArray(0)); - - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + List tokens = testHelper.getTokenizedSentences().get(i); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF, expFTemp); + expM = Nd4j.vstack(expM, expMTemp); } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); } - - INDArray expF = Nd4j.vstack(expEx0, expEx1); - INDArray expM = Nd4j.vstack(expM0, expM1); - assertEquals(expF, mds.getFeatures(0)); assertEquals(expM, mds.getFeaturesMaskArray(0)); - assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); - b.next(); //pop the third element assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); - forInference.set(0, toTokenize2); //Same thing, but with segment ID also b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); mds = b.next(); assertEquals(2, mds.getFeatures().length); - //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null... - assertEquals(2, b.featurizeSentences(forInference).getFirst().length); //Segment ID should be all 0s for single segment task INDArray segmentId = expM.like(); assertEquals(segmentId, mds.getFeatures(1)); - assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]); + assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]); } @Test(timeout = 20000L) public void testBertUnsupervised() throws Exception { + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); //Task 1: Unsupervised - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.UNSUPERVISED) .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5)) .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX) .maskToken("[MASK]") .build(); - System.out.println("Mask token index: " + t.getVocab().get("[MASK]")); + System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]")); MultiDataSet mds = b.next(); System.out.println(mds.getFeatures(0)); @@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest { System.out.println(mds.getLabels(0)); System.out.println(mds.getLabelsMaskArray(0)); - b.next(); //pop the third element assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); @@ -164,40 +159,34 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testLengthHandling() throws Exception { - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - System.out.println(tokens); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - System.out.println(tokens2); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + List tokens = testHelper.getTokenizedSentences().get(i); + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF, expFTemp); + expM = Nd4j.vstack(expM, expMTemp); } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); } - INDArray expF = Nd4j.vstack(expEx0, expEx1); - INDArray expM = Nd4j.vstack(expM0, expM1); - //-------------------------------------------------------------- //Fixed length: clip or pad - already tested in other tests @@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest { //Any length: as long as we need to fit longest sequence BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); MultiDataSet mds = b.next(); @@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest { assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0)); - assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); //Clip only: clip to maximum, but don't pad if less b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); - mds = b.next(); expShape = new long[]{2, 14}; assertArrayEquals(expShape, mds.getFeatures(0).shape()); assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); @@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testMinibatchPadding() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - String toTokenize3 = "Goodnight noises everywhere"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - forInference.add(toTokenize3); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); - } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); - } - - INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM3 = Nd4j.create(DataType.INT, 1, 16); - List tokens3 = t.create(toTokenize3).getTokens(); - for (int i = 0; i < tokens3.size(); i++) { - String token = tokens3.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); - } - int idx = m.get(token); - expEx3.putScalar(0, i, idx); - expM3.putScalar(0, i, 1); - } - + int minibatchSize = 3; + TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize); INDArray zeros = Nd4j.create(DataType.INT, 1, 16); - INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros); - INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros); - INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}}); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + List tokens = testHelper.getTokenizedSentences().get(i); + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF.dup(), expFTemp); + expM = Nd4j.vstack(expM.dup(), expMTemp); + } + } + + expF = Nd4j.vstack(expF, zeros); + expM = Nd4j.vstack(expM, zeros); + INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}}); INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1); expLM.putScalar(0, 0, 1); expLM.putScalar(1, 0, 1); @@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest { //-------------------------------------------------------------- BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(4) + .minibatchSize(minibatchSize + 1) .padMinibatches(true) - .sentenceProvider(new TestSentenceProvider()) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); @@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expL, mds.getLabels(0)); assertEquals(expLM, mds.getLabelsMaskArray(0)); - assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); } + /* + Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair + Checks different lengths for max length to check popping and padding + */ @Test public void testSentencePairsSingle() throws IOException { - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; boolean prependAppend; - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - int shortL = t.create(shortSent).countTokens(); - int longL = t.create(longSent).countTokens(); + int numOfSentences; + + TestSentenceHelper testHelper = new TestSentenceHelper(); + int shortL = testHelper.getShortestL(); + int longL = testHelper.getLongestL(); Triple multiDataSetTriple; - MultiDataSet shortLongPair, shortSentence, longSentence; + MultiDataSet fromPair, leftSide, rightSide; // check for pair max length exactly equal to sum of lengths - pop neither no padding // should be the same as hstack with segment ids 1 for second sentence prependAppend = true; - multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).addi(1); - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + numOfSentences = 1; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length greater than sum of lengths - pop neither with padding // features should be the same as hstack of shorter and longer padded with prepend/append // segment id should 1 only in the longer for part of the length of the sentence prependAppend = true; - multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + numOfSentences = 1; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length less than shorter sentence - pop both //should be the same as hstack with segment ids 1 for second sentence if no prepend/append - int maxL = shortL - 2; + int maxL = 5;//checking odd + numOfSentences = 3; prependAppend = false; - multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).addi(1); - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).addi(1); + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); } + /* + Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs + Checks various max lengths + Has sentences of varying lengths + */ @Test public void testSentencePairsUnequalLengths() throws IOException { - //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row - //should be identical to hstack if there is no append, prepend - //batch size is 2 - int mbS = 4; - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping - String sent2 = "Goodnight moon"; //shorter than shortSent - no popping - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - int shortL = t.create(shortSent).countTokens(); - int longL = t.create(longSent).countTokens(); - int sent1L = t.create(sent1).countTokens(); - int sent2L = t.create(sent2).countTokens(); - //won't check 2*shortL + 1 because this will always pop on the left - for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) { + + int minibatchSize = 4; + int numOfSentencesinIter = 3; + + TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter); + int shortL = testPairHelper.getShortL(); + int longL = testPairHelper.getLongL(); + int sent1L = testPairHelper.getSentenceALen(); + int sent2L = testPairHelper.getSentenceBLen(); + + System.out.println("Sentence Pairs, Left"); + System.out.println(testPairHelper.getSentencesLeft()); + System.out.println("Sentence Pairs, Right"); + System.out.println(testPairHelper.getSentencesRight()); + + //anything outside this range more will need to check padding,truncation + for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) { + + System.out.println("Running for max length = " + maxL); + MultiDataSet leftMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either - .sentenceProvider(new TestSentenceProvider()) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider()) .padMinibatches(true) .build().next(); MultiDataSet rightMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either - .sentenceProvider(new TestSentenceProvider(true)) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider()) .padMinibatches(true) .build().next(); MultiDataSet pairMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either - .sentencePairProvider(new TestSentencePairProvider()) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) + .sentencePairProvider(testPairHelper.getPairSentenceProvider()) .padMinibatches(true) .build().next(); - //Left sentences here are {{shortSent}, - // {longSent}, - // {Sent1}} - //Right sentences here are {{longSent}, - // {shortSent}, - // {Sent2}} - //The sentence pairs here are {{shortSent,longSent}, - // {longSent,shortSent} - // {Sent1, Sent2}} - //CHECK FEATURES - INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL); + INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL); //left side INDArray leftFeatures = leftMDS.getFeatures(0); INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL)); INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL)); - INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L)); + INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L)); //right side INDArray rightFeatures = rightMDS.getFeatures(0); INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL)); INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL)); - INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L)); + INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L)); //expected pair - combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat)); - combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat)); - combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat)); + combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat)); + combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat)); + combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat)); assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]); assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape()); assertEquals(combinedFeat, pairMDS.getFeatures(0)); //CHECK SEGMENT ID - INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL); + INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL); combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1); combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1); - combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1); assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape()); assertEquals(maxL, combinedFetSeg.shape()[1]); assertEquals(combinedFetSeg, pairMDS.getFeatures(1)); + + testPairHelper.getPairSentenceProvider().reset(); } } @Test public void testSentencePairFeaturizer() throws IOException { - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List> listSentencePair = new ArrayList<>(); - listSentencePair.add(new Pair<>(shortSent, longSent)); - listSentencePair.add(new Pair<>(longSent, shortSent)); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int minibatchSize = 2; + TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize); BertIterator b = BertIterator.builder() - .tokenizer(t) - .minibatchSize(2) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .padMinibatches(true) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128) - .sentencePairProvider(new TestSentencePairProvider()) + .sentencePairProvider(testPairHelper.getPairSentenceProvider()) .prependToken("[CLS]") .appendToken("[SEP]") .build(); @@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest { INDArray[] featuresArr = mds.getFeatures(); INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays(); - Pair p = b.featurizeSentencePairs(listSentencePair); + Pair p = b.featurizeSentencePairs(testPairHelper.getSentencePairs()); assertEquals(p.getFirst().length, 2); assertEquals(featuresArr[0], p.getFirst()[0]); assertEquals(featuresArr[1], p.getFirst()[1]); - //assertEquals(p.getSecond().length, 2); assertEquals(featuresMaskArr[0], p.getSecond()[0]); - //assertEquals(featuresMaskArr[1], p.getSecond()[1]); } /** - * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append + * Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator + * with given max lengths and whether to prepend/append * Idea is the sentence pair dataset can be constructed from the single sentence datasets - * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" - * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope." - * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" */ - private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend) throws IOException { + private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend, int numSentences) throws IOException { BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); int maxforPair = maxLengths.getFirst(); int maxPartOne = maxLengths.getSecond(); @@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest { BertIterator.Builder commonBuilder; commonBuilder = BertIterator.builder() .tokenizer(t) - .minibatchSize(1) + .minibatchSize(4) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) .vocabMap(t.getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION); - BertIterator shortLongPairFirstIter = commonBuilder + BertIterator pairIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) - .sentencePairProvider(new TestSentencePairProvider()) + .sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator shortFirstIter = commonBuilder + BertIterator leftIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) - .sentenceProvider(new TestSentenceProvider()) + .sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator longFirstIter = commonBuilder + BertIterator rightIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) - .sentenceProvider(new TestSentenceProvider(true)) + .sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider()) .prependToken(null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next()); + return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next()); } - private static class TestSentenceProvider implements LabeledSentenceProvider { + @Getter + private static class TestSentencePairsHelper { - private int pos = 0; - private boolean invert; + private List sentencesLeft; + private List sentencesRight; + private List> sentencePairs; + private List> tokenizedSentencesLeft; + private List> tokenizedSentencesRight; + private List labels; + private int shortL; + private int longL; + private int sentenceALen; + private int sentenceBLen; + private BertWordPieceTokenizerFactory tokenizer; + private CollectionLabeledPairSentenceProvider pairSentenceProvider; - private TestSentenceProvider() { - this.invert = false; + private TestSentencePairsHelper() throws IOException { + this(3); } - private TestSentenceProvider(boolean invert) { - this.invert = invert; - } - - @Override - public boolean hasNext() { - return pos < totalNumSentences(); - } - - @Override - public Pair nextSentence() { - Preconditions.checkState(hasNext()); - if (pos == 0) { - pos++; - if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive"); - return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - } else { - if (pos == 1) { - pos++; - if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - return new Pair<>("I saw a girl with a telescope.", "positive"); + private TestSentencePairsHelper(int minibatchSize) throws IOException { + sentencesLeft = new ArrayList<>(); + sentencesRight = new ArrayList<>(); + sentencePairs = new ArrayList<>(); + labels = new ArrayList<>(); + tokenizedSentencesLeft = new ArrayList<>(); + tokenizedSentencesRight = new ArrayList<>(); + tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + sentencesLeft.add(shortSentence); + sentencesRight.add(longSentence); + sentencePairs.add(new Pair<>(shortSentence, longSentence)); + labels.add("positive"); + if (minibatchSize > 1) { + sentencesLeft.add(longSentence); + sentencesRight.add(shortSentence); + sentencePairs.add(new Pair<>(longSentence, shortSentence)); + labels.add("negative"); + if (minibatchSize > 2) { + sentencesLeft.add(sentenceA); + sentencesRight.add(sentenceB); + sentencePairs.add(new Pair<>(sentenceA, sentenceB)); + labels.add("positive"); } - pos++; - if (!invert) - return new Pair<>("Goodnight noises everywhere", "positive"); - return new Pair<>("Goodnight moon", "positive"); } - } - - @Override - public void reset() { - pos = 0; - } - - @Override - public int totalNumSentences() { - return 3; - } - - @Override - public List allLabels() { - return Arrays.asList("positive", "negative"); - } - - @Override - public int numLabelClasses() { - return 2; + for (int i = 0; i < minibatchSize; i++) { + List tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens(); + List tokensR = tokenizer.create(sentencesRight.get(i)).getTokens(); + if (i == 0) { + shortL = tokensL.size(); + longL = tokensR.size(); + } + if (i == 2) { + sentenceALen = tokensL.size(); + sentenceBLen = tokensR.size(); + } + tokenizedSentencesLeft.add(tokensL); + tokenizedSentencesRight.add(tokensR); + } + pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null); } } - private static class TestSentencePairProvider implements LabeledPairSentenceProvider { + @Getter + private static class TestSentenceHelper { - private int pos = 0; + private List sentences; + private List> tokenizedSentences; + private List labels; + private int shortestL = 0; + private int longestL = 0; + private BertWordPieceTokenizerFactory tokenizer; + private CollectionLabeledSentenceProvider sentenceProvider; - @Override - public boolean hasNext() { - return pos < totalNumSentences(); + private TestSentenceHelper() throws IOException { + this(false, 2); } - @Override - public Triple nextSentencePair() { - Preconditions.checkState(hasNext()); - if (pos == 0) { - pos++; - return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive"); - } else { - if (pos == 1) { - pos++; - return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative"); + private TestSentenceHelper(int minibatchSize) throws IOException { + this(false, minibatchSize); + } + + private TestSentenceHelper(boolean alternateOrder) throws IOException { + this(false, 3); + } + + private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException { + sentences = new ArrayList<>(); + labels = new ArrayList<>(); + tokenizedSentences = new ArrayList<>(); + tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + if (!alternateOrder) { + sentences.add(shortSentence); + labels.add("positive"); + if (minibatchSize > 1) { + sentences.add(longSentence); + labels.add("negative"); + if (minibatchSize > 2) { + sentences.add(sentenceA); + labels.add("positive"); + } + } + } else { + sentences.add(longSentence); + labels.add("negative"); + if (minibatchSize > 1) { + sentences.add(shortSentence); + labels.add("positive"); + if (minibatchSize > 2) { + sentences.add(sentenceB); + labels.add("positive"); + } } - pos++; - return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive"); } - } - - @Override - public void reset() { - pos = 0; - } - - @Override - public int totalNumSentences() { - return 3; - } - - @Override - public List allLabels() { - return Arrays.asList("positive", "negative"); - } - - @Override - public int numLabelClasses() { - return 2; + for (int i = 0; i < sentences.size(); i++) { + List tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens(); + if (i == 0) + shortestL = tokenizedSentence.size(); + if (tokenizedSentence.size() > longestL) + longestL = tokenizedSentence.size(); + if (tokenizedSentence.size() < shortestL) + shortestL = tokenizedSentence.size(); + tokenizedSentences.add(tokenizedSentence); + } + sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 2aec66a77..64b033133 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -254,6 +254,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { uiEventRoutingThread = new Thread(new StatsEventRouterRunnable()); uiEventRoutingThread.setDaemon(true); uiEventRoutingThread.start(); + + String address = UIServer.getInstance().getAddress(); + log.info("Deeplearning4j UI server started at: {}", address); } private List extractArgsFromRoute(String path, RoutingContext rc) { @@ -317,7 +320,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { @Override public String getAddress() { - return "https://localhost:" + server.actualPort(); + return "http://localhost:" + server.actualPort(); } @Override @@ -421,7 +424,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } private void runHelper() throws Exception { - log.info("VertxUIServer.StatsEventRouterRunnable started"); + log.trace("VertxUIServer.StatsEventRouterRunnable started"); //Idea: collect all event stats, and route them to the appropriate modules while (!shutdown.get()) { diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index cfad05b49..d89ef8c72 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1256,6 +1256,9 @@ namespace nd4j { FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j); template FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); + template + FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w); + /** * returns array element with given index @@ -1268,6 +1271,8 @@ namespace nd4j { FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; template FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const; /** @@ -1711,7 +1716,7 @@ namespace nd4j { if (isEmpty()) return false; - return shape::isMatrix(this->_shapeInfo); + return 0 != shape::isMatrix(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// @@ -1751,7 +1756,7 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// bool NDArray::isScalar() const { - return shape::isScalar(this->_shapeInfo); + return 0 != shape::isScalar(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// @@ -2082,7 +2087,7 @@ template T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !"); + throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); if (DataTypeUtils::fromT() != _dataType) throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); @@ -2095,6 +2100,23 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { return *(reinterpret_cast(bufferWithOffset(offset))); } +template +T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) { + + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3)) + throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if(!isActualOnHostSide()) + syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(getShapeInfo(), coords); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); +} + //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i) const { @@ -2133,7 +2155,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !"); + throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); if (DataTypeUtils::fromT() != _dataType) throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); @@ -2146,6 +2168,23 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { return *(reinterpret_cast(bufferWithOffset(offset))); } + template + T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const { + + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) + throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if(!isActualOnHostSide()) + syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(getShapeInfo(), coords); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); + } + #ifndef __JAVACPP_HACK__ //////////////////////////////////////////////////////////////////////// std::shared_ptr NDArray::getDataBuffer() const { diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index df358b64f..5adff5853 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2348,7 +2348,7 @@ NDArray NDArray::operator-(const NDArray& other) const { NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); + NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&result}, {this, &other}); return result; @@ -2394,7 +2394,7 @@ NDArray NDArray::operator/(const NDArray& other) const { NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); + NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&result}, {this, &other}); return result; diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index bf5a3eb6e..5fe7455fc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -46,7 +46,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, DataType::INT64); + ->setAllowedOutputTypes(1, {ALL_INTS}); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp index 0c1aeba61..99053561c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp @@ -35,6 +35,8 @@ namespace nd4j { int width; int height; auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf()); @@ -57,7 +59,7 @@ namespace nd4j { if (block.numB()> 1) halfPixelAlign = block.getBArguments()->at(1); } - REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners"); + REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false"); auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp index f60f14fdc..f1f79b08f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp @@ -32,8 +32,10 @@ namespace nd4j { NDArray* output = OUTPUT_VARIABLE(0); int width; int height; - bool center = false; // - default value + bool alignCorners = false; // - default value auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " "tensor, but input has rank %i", image->rankOf()); @@ -46,21 +48,25 @@ namespace nd4j { auto newImageSize = INPUT_VARIABLE(1); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - if (block.numI() == 1) { - center = 0 != INT_ARG(0); - } + height = newImageSize->e(0); + width = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - if (block.numI() == 3) - center = 0 != INT_ARG(2); + REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); } - return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); + if (block.numB() > 0) + alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) + halfPixelCenter = B_ARG(1); + + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false"); + + return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } DECLARE_SHAPE_FN(resize_bilinear) { @@ -83,7 +89,7 @@ namespace nd4j { height = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided."); + REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided."); width = INT_ARG(0); height = INT_ARG(1); } @@ -101,7 +107,12 @@ namespace nd4j { outputShape[2] = height; outputShape[3] = in[3]; } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } + else { + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); + } shapeList->push_back(CONSTANT(outputShape)); return shapeList; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp index 8733cb9d5..6c18e61e1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp @@ -31,35 +31,40 @@ namespace nd4j { auto image = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + auto inRank = image->rankOf(); int width; int height; - bool center = false; // - default value + bool alignCorners = false; // - default value + if (output->isEmpty()) return Status::OK(); if (block.width() > 1) { auto newImageSize = INPUT_VARIABLE(1); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - if (block.numI() == 1) { - center = 0 != INT_ARG(0); - } + height = newImageSize->e(0); + width = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - if (block.numI() == 3) - center = 0 != INT_ARG(2); + REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); } - auto inRank = image->rankOf(); + if (block.numB() > 0) + alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) + halfPixelCenter = B_ARG(1); + REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str()); - auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false"); + REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height); + auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); - return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); + return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } DECLARE_SHAPE_FN(resize_nearest_neighbor) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index d334caed2..16ddd17da 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -120,6 +120,27 @@ namespace helpers { } }; + // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. + struct HalfPixelScalerNN { + HalfPixelScalerNN(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale; + } + }; + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + LegacyScaler(){}; + inline float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + struct WeightsAndIndices { float _weight0; float _weight1; @@ -133,7 +154,8 @@ namespace helpers { int _advance; // advance value. }; - inline void computeInterpolationWeights(Nd4jLong outSize, + template + inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize, Nd4jLong inSize, double scale, BilinearInterpolationData *interpolationData) { @@ -143,10 +165,12 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto k = start; k < stop; k++) { auto i = (outSize - k - 1); - double in = i * scale; - interpolationData[i]._bottomIndex = static_cast(in); - interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1); - interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex; + double const in = scaler(i, scale); + double const in_f = nd4j::math::nd4j_floor(in); + double const in_c = nd4j::math::nd4j_ceil(in); + interpolationData[i]._bottomIndex = nd4j::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); + interpolationData[i]._topIndex = nd4j::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i]._interpolarValue = in - in_f; } }; samediff::Threads::parallel_for(func, 0, outSize); @@ -156,29 +180,29 @@ namespace helpers { * Computes the bilinear interpolation from the appropriate 4 float points * and the linear interpolation weights. */ - static void - resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const& xs, - std::vector const& ys, - NDArray *output); +// static void +// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// Nd4jLong outWidth, Nd4jLong channels, +// std::vector const& xs, +// std::vector const& ys, +// NDArray *output); - template + template static void - resizeImage_(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, + resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, std::vector const &xs, std::vector const &ys, - NDArray *output) { + Z* pOutputBuf) { Nd4jLong inRowSize = inWidth * channels; Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; - T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction +// T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction BilinearInterpolationData const* xsPtr = xs.data(); - T* pOutputBuf = output->dataBuffer()->primaryAsT(); +// T* pOutputBuf = output->dataBuffer()->primaryAsT(); auto computeBilinear = [](double topLeft, double topRight, double bottomLeft, double bottomRight, double xVal, double yVal) { @@ -214,8 +238,12 @@ namespace helpers { samediff::Threads::parallel_tad(func, 0, batchSize); } - template - static int resizeBilinearFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) { + template + static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray *output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -230,28 +258,20 @@ namespace helpers { return ND4J_STATUS_OK; } - // Special case for TF compatibility - if((center && inHeight < 2) || (center && inWidth < 2)){ - center = false; - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight)); - float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth)); - std::vector ys(outHeight + 1); std::vector xs(outWidth + 1); + if (halfPixelCenter) { + computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale, + ys.data()); + computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data()); - // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights(outHeight, inHeight, heightScale, - ys.data()); - computeInterpolationWeights(outWidth, inWidth, widthScale, xs.data()); - + } + else { + // Compute the cached interpolation weights on the x and y dimensions. + computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale, + ys.data()); + computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data()); + } int xsSize = xs.size(); // Scale x interpolation weights to avoid a multiplication during iteration. auto func = PRAGMA_THREADS_FOR { @@ -262,71 +282,84 @@ namespace helpers { }; samediff::Threads::parallel_for(func, 0, xsSize); - resizeImage(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output); + resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT()); return ND4J_STATUS_OK; } - template - int resizeNeighborFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); + template + void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + const Nd4jLong batchSize = st.batchSize; + const Nd4jLong inHeight = st.inHeight; + const Nd4jLong inWidth = st.inWidth; + const Nd4jLong channels = st.channels; - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return Status::OK(); - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight)); - double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth)); + const Nd4jLong outHeight = st.outHeight; + const Nd4jLong outWidth = st.outWidth; + Scaler scaler; auto func = PRAGMA_THREADS_FOR_2D { for (auto b = start_x; b < stop_x; b += inc_x) { for (auto y = start_y; y < stop_y; y += inc_y) { - Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(y * heightScale)) : static_cast(nd4j::math::p_floor(y * heightScale)), inHeight - 1); - + auto posY = alignCorners ? static_cast(nd4j::math::p_round(scaler(y, st.heightScale))) : static_cast(nd4j::math::p_floor(scaler(y, st.heightScale))); + Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenter) { + inY = nd4j::math::nd4j_max(0LL, inY); + } for (auto x = 0; x < outWidth; ++x) { - Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(x * widthScale)) : static_cast(nd4j::math::p_floor(x * widthScale)),inWidth - 1); + auto posX = alignCorners ? static_cast(nd4j::math::p_round(scaler(x, st.widthScale))) : static_cast(nd4j::math::p_floor(scaler(x, st.widthScale))); + Nd4jLong inX = nd4j::math::nd4j_min(posX,inWidth - 1); + if (halfPixelCenter) { + inX = nd4j::math::nd4j_max(0LL, inX); + } + // copy pixel over all channels for (auto e = 0; e < channels; e++) - output->p(b, y, x, e, images->e(b, inY, inX, e)); + output->t(b, y, x, e) = images->t(b, inY, inX, e); } } } }; samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); + } + + template + int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + // Handle no-op resizes efficiently. + if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) { + output->assign(images); + return Status::OK(); + } + + if (halfPixelCenter) + resizeNeighbor(st, images, alignCorners, true, output); + else + resizeNeighbor(st, images, alignCorners, false, output); return Status::OK(); } - void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const &xs, - std::vector const &ys, - NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, - (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), - LIBND4J_TYPES); +// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// Nd4jLong outWidth, Nd4jLong channels, +// std::vector const &xs, +// std::vector const &ys, +// NDArray *output) { +// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, +// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), +// NUMERIC_TYPES, FLOAT_TYPES); +// } + + int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); } - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, - (images, width, height, center, output), LIBND4J_TYPES); - } - - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { + int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray *output) { BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, - (images, width, height, center, output), LIBND4J_TYPES); + (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); } @@ -586,16 +619,6 @@ namespace helpers { } } -// Older incorrect scaling method that causes all resizes to have a slight -// translation leading to inconsistent results. For example, a flip then a -// resize gives different results then a resize then a flip. - struct LegacyScaler { - LegacyScaler(){}; - inline float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, const bool half_pixel_centers, std::vector* x_wais) { @@ -847,7 +870,7 @@ namespace helpers { // simplified bicubic resize without antialiasing // template - int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output) { ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align int res = st.validateAndCreateOutput(image, width, height); @@ -856,17 +879,17 @@ namespace helpers { return res; } - int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output) { BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); } // ------------------------------------------------------------------------------------------------------------------ // - int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break; + case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; + case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; case kResizeLanczos5: case kResizeGaussian: diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 0541742ca..4f025d851 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -13,6 +13,20 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ // // @author sgazeos@gmail.com @@ -32,6 +46,38 @@ namespace helpers { // https://en.wikipedia.org/wiki/Bilinear_interpolation) double interpolarValue; }; + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + _CUDA_HD LegacyScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. + struct HalfPixelScaler { + _CUDA_HD HalfPixelScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } + }; + + + // Utility functions + // calculateResizeScale determines the float scaling factor. + inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); + } + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // computeInterpolationWeights kernel // outSize - output length @@ -39,6 +85,7 @@ namespace helpers { // scale - input scale // interporationData - result // + template static __global__ void computeInterpolationWeights(Nd4jLong outSize, Nd4jLong inSize, double scale, @@ -48,12 +95,18 @@ namespace helpers { interpolationData[outSize].topIndex = 0; auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto step = blockDim.x * gridDim.x; - + Scaler scaler; for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { - double in = i * scale; - interpolationData[i].bottomIndex = static_cast(in); - interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); - interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; + double in = scaler(i, scale); +// interpolationData[i].bottomIndex = static_cast(in); +// interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); +// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; + double const in_f = nd4j::math::p_floor(in); + double const in_c = nd4j::math::p_ceil(in); + interpolationData[i].bottomIndex = nd4j::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); + interpolationData[i].topIndex = nd4j::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i].interpolarValue = in - in_f; + if (channels) { math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels); math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); @@ -72,31 +125,33 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm kernel // - template - static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, T* outputYptr, Nd4jLong* outputShape, Nd4jLong batchSize, - Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, - BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { + template + static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr, + Nd4jLong* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, + Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index auto pX = input + batch * inBatchNumValues; for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { - const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; - const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; + const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; + const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; double yVal = ys_[y].interpolarValue; auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; - for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) { + for (Nd4jLong x = 0; x < outWidth; x++) { auto xsBottom = xs_[x].bottomIndex; auto xsTop = xs_[x].topIndex; auto xVal = xs_[x].interpolarValue; // process interpolation for all channels - for (int c = threadIdx.z; c < channels; c += blockDim.z) { - double topLeft(ys_input_lower_ptr[xsBottom + c]); - double topRight(ys_input_lower_ptr[xsTop + c]); - double bottomLeft(ys_input_upper_ptr[xsBottom + c]); - double bottomRight(ys_input_upper_ptr[xsTop + c]); - double top = topLeft + (topRight - topLeft) * xVal; - double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - pZ[x * channels + c] = T(top + (bottom - top) * yVal); + for (int c = 0; c < channels; c++) { + Z topLeft(ys_input_lower_ptr[xsBottom + c]); + Z topRight(ys_input_lower_ptr[xsTop + c]); + Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); + Z bottomRight(ys_input_upper_ptr[xsTop + c]); + Z top = topLeft + (topRight - topLeft) * xVal; + Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + Z resVal = Z(top + (bottom - top) * yVal); + pZ[x * channels + c] = resVal; } } } @@ -105,7 +160,7 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with - template + template static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, @@ -115,12 +170,13 @@ namespace helpers { Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; auto stream = context->getCudaStream(); - T const *input_b_ptr = reinterpret_cast(images->getSpecialBuffer()); // this works only with 'c' direction - T *output_y_ptr = reinterpret_cast(output->specialBuffer()); + T const* pInput = images->getDataBuffer()->specialAsT(); //reinterpret_cast(images->getSpecialBuffer()); // this works only with 'c' direction + F* pOutput = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); dim3 batchSizeBlock(batchSize, 1, 1); dim3 pictureBlock(outHeight, outWidth, channels); - resizeImageKernel<<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize, - outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_); + resizeImageKernel<<<256, 256, 256, *stream>>>(pInput, images->getSpecialShapeInfo(), pOutput, + output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, + inBatchNumValues, xs_, ys_); auto err = cudaStreamSynchronize(*stream); if (err != 0) { @@ -129,8 +185,9 @@ namespace helpers { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { + template + static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, + int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) { const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -145,19 +202,8 @@ namespace helpers { return ND4J_STATUS_OK; } - // Special case for TF compatibility - if((center && inHeight < 2) || (center && inWidth < 2)){ - center = false; - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight)); - float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); BilinearInterpolationData* xs_;// = xs.data(); BilinearInterpolationData* ys_;// = xs.data(); @@ -173,12 +219,24 @@ namespace helpers { } auto stream = context->getCudaStream(); // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights<<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights<<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - + if (halfPixelCenter) { + computeInterpolationWeights < + HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights < + HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); + } + else { + computeInterpolationWeights < + LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights < + LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); + } + printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth); NDArray::prepareSpecialUse({output}, {images}); - resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); + resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); + err = cudaStreamSynchronize(*stream); NDArray::registerSpecialUse({output}, {images}); + err = cudaFree(xs_); if (err != 0) { throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err); @@ -197,20 +255,28 @@ namespace helpers { // template static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, - Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center) { + Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) { //for (int b = blockIdx.x; b < batchSize; b += gridDim.x) if (blockIdx.x < batchSize) { auto b = blockIdx.x; for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { - Nd4jLong inY = nd4j::math::nd4j_min( - (center) ? static_cast(nd4j::math::p_round(y * heightScale)) : static_cast(nd4j::math::p_floor( - y * heightScale)), inHeight - 1); + auto posY = alignCorners ? static_cast(nd4j::math::p_round(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast(nd4j::math::p_floor( + halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)); + Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenters) { + inY = nd4j::math::nd4j_max(0LL, inY); + } + for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { - Nd4jLong inX = nd4j::math::nd4j_min( - (center) ? static_cast(nd4j::math::p_round(x * widthScale)) : static_cast(nd4j::math::p_floor( - x * widthScale)), inWidth - 1); + auto posX = alignCorners ? static_cast(nd4j::math::p_round(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast(nd4j::math::p_floor( + halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)); + Nd4jLong inX = nd4j::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenters) { + inX = nd4j::math::nd4j_max(0LL, inX); + } + auto start = blockIdx.z * blockDim.z + threadIdx.z; auto step = blockDim.z * gridDim.z; @@ -231,7 +297,8 @@ namespace helpers { // resizeNeighborFunctor - main algorithm by nearest neighbor // template - int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { + int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenters, NDArray* output) { const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -246,25 +313,24 @@ namespace helpers { return ND4J_STATUS_OK; } - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight)); - double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth)); - auto imagesBuffer = reinterpret_cast(images->getSpecialBuffer()); - auto outputBuffer = reinterpret_cast(output->specialBuffer()); +// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) || +// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { +// // wrong input data +// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); +// return ND4J_STATUS_BAD_ARGUMENTS; +// } +// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight)); +// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + auto imagesBuffer = images->getDataBuffer()->specialAsT();//reinterpret_cast(images->getSpecialBuffer()); + auto outputBuffer = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); auto stream = context->getCudaStream(); - //T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong* outputShape, - // Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center - //input, inputShape, output, outputShape, - // batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center NDArray::prepareSpecialUse({output}, {images}); resizeNeighborKernel<<>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(), - batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center); + batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters); NDArray::registerSpecialUse({output}, {images}); return Status::OK(); @@ -275,39 +341,38 @@ namespace helpers { void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), + resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, + xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images, + BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, - Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), LIBND4J_TYPES); + Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), + NUMERIC_TYPES, FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES); + int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images, + width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, +// NDArray const* images, int const width, int const height, bool const alignCorners, +// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES); + int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output) { + BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images, - int width, int height, bool center, NDArray* output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images, +// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Bicubic interpolation //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Utility functions and classes - - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } - struct ImageResizerState { explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) : _alignCorners(alignCorners), @@ -362,17 +427,6 @@ namespace helpers { bool _halfPixelCenters; }; - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the -// floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - _CUDA_HD HalfPixelScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - struct WeightsAndIndices { float _weight0; float _weight1; @@ -547,16 +601,6 @@ namespace helpers { } } -// Older incorrect scaling method that causes all resizes to have a slight -// translation leading to inconsistent results. For example, a flip then a -// resize gives different results then a resize then a flip. - struct LegacyScaler { - _CUDA_HD LegacyScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) { auto start = blockIdx.x * blockDim.x + threadIdx.x; auto step = blockDim.x * gridDim.x; @@ -906,8 +950,8 @@ namespace helpers { int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break; + case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; + case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; case kResizeLanczos5: case kResizeGaussian: diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index aceebf7a0..90e15b21f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -30,6 +30,67 @@ namespace nd4j { namespace ops { namespace helpers { + template + static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) { + auto input = reinterpret_cast(vinput); + auto output = reinterpret_cast(voutput); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + // this means that we'll have additional cycle, to move middle element + auto div = numOfElemsToReverse / 2; + auto odd = numOfElemsToReverse % 2 != 0; + auto rlimit = odd ? limit / 2 + 1 : limit / 2; + + // all threads operate in the same input/output space + for (uint64_t e = tid; e < rlimit; e += step) { + // finding out the TAD we're going to process + auto tadId = e / div; + + if (tadId >= numTads) + continue; + + // now finding out element within tad + auto idx = e % div; + + //printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx); + + auto tadInput = input + inputTadOffsets[tadId]; + auto tadOutput = output + outputTadOffsets[tadId]; + + // we're calculating offsets within input TAD + auto fOffset = shape::getIndexOffset(idx, inputTadShape); + auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); + + // now we're storing input values + auto v1 = tadInput[fOffset]; + auto v2 = tadInput[lOffset]; + + // now we're calculating offsets within output TAD + auto zfOffset = shape::getIndexOffset(idx, outputTadShape); + auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); + + // and saving values to output arrays + tadOutput[zfOffset] = v2; + tadOutput[zlOffset] = v1; + } + + // moving odd element in blocks + if (odd && threadIdx.x == 0) { + for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { + auto tadInput = input + inputTadOffsets[e]; + auto tadOutput = output + outputTadOffsets[e]; + + auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); + auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); + + tadOutput[zOffset] = tadInput[xOffset]; + } + } + + } + + template static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -52,7 +113,7 @@ namespace helpers { auto odd = numOfElemsToReverse % 2 != 0; auto limit = numOfElemsToReverse / 2; - for (Nd4jLong e = tid; e < limit; e += step) { + for (uint64_t e = tid; e < limit; e += step) { // we're calculating offsets within input array auto fOffset = shape::getIndexOffset(e, inputShape); auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); @@ -80,13 +141,19 @@ namespace helpers { } template - static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { + static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) { + auto stream = context->getCudaStream(); + reverseTadKernel<<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength); + } + + template + static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { auto stream = context->getCudaStream(); Nd4jLong numOfReverse = numOfElemsToReverse; if (numOfElemsToReverse == 0) numOfReverse = input->lengthOf(); - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); + reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); } @@ -153,27 +220,23 @@ namespace helpers { // we need to reverse axis only if that's new op std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; std::vector axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); - auto listOut = output->allTensorsAlongDimension(dimensions); - auto listIn = input->allTensorsAlongDimension(dimensions); - NDArray *subArrIn, *subArrOut; NDArray::prepareSpecialUse({output}, {input}); - for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size() - subArrIn = listIn->at(i); - subArrOut = listOut->at(i); - BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES); + + if (packX.numberOfTads() == 1) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); + } else { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES); } - //BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast(input), output, (int)0), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {input}); - delete listOut; - delete listIn; } -BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h index 22c41833b..d52fd74f7 100644 --- a/libnd4j/include/ops/declarable/helpers/image_resize.h +++ b/libnd4j/include/ops/declarable/helpers/image_resize.h @@ -37,15 +37,15 @@ namespace helpers { kResizeArea }; - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, - NDArray* output); - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, - NDArray* output); - int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output); + int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output); + int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool preserveAspectRatio, bool antialias, NDArray* output); - int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output); - int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes, diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 99cc98af9..eccb73c6c 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -419,7 +419,17 @@ TEST_F(ConvolutionTests1, sconv2d_1) { ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { - TypeParam _expBFF[] = {108.9405008, 109.5920008, 110.2435008, 110.8950008, 111.5465008, 112.1980008, 115.4555008, 116.1070008, 116.7585008, 117.410000, 118.061500, 118.7130009, 121.9705009, 122.6220009, 123.2735009, 123.9250009, 124.5765009, 125.2280009, 128.4855009, 129.1370009, 129.7885009, 130.4400009, 131.09150, 131.74300, 135.0005010, 135.6520010, 136.3035010, 136.9550010, 137.6065010, 138.2580010, 141.5155010, 142.1670010, 142.8185010, 143.4700010, 144.1215010, 144.7730010, 248.9617514, 250.670751, 252.3797515, 254.0887515, 255.7977515, 257.5067515, 266.0517515, 267.7607515, 269.469751, 271.1787516, 272.8877516, 274.5967516, 283.1417516, 284.8507516, 286.5597516, 288.268751, 289.9777517, 291.6867517, 300.2317517, 301.9407517, 303.6497517, 305.3587517, 307.067751, 308.7767518, 317.3217518, 319.0307518, 320.7397518, 322.4487518, 324.157751, 325.866751, 334.4117519, 336.1207519, 337.8297519, 339.5387519, 341.2477519, 342.95675, 388.9829964, 391.7494964, 394.5159964, 397.2824964, 400.048996, 402.8154963, 416.647996, 419.4144962, 422.1809962, 424.9474962, 427.7139962, 430.4804962, 444.3129961, 447.0794961, 449.8459961, 452.6124960, 455.3789960, 458.1454960, 471.9779959, 474.7444959, 477.5109959, 480.2774959, 483.0439959, 485.8104958, 499.6429958, 502.4094957, 505.1759957, 507.9424957, 510.7089957, 513.4754957, 527.3079956, 530.0744956, 532.8409956, 535.607495, 538.3739955, 541.1404955, 529.0042487, 532.8282487, 536.6522487, 540.4762487, 544.3002487, 548.1242487, 567.2442487, 571.068248, 574.892248, 578.716248, 582.540248, 586.3642486, 605.4842486, 609.3082486, 613.1322486, 616.9562486, 620.7802486, 624.6042486, 643.7242486, 647.5482486, 651.3722486, 655.1962486, 659.0202486, 662.8442486, 681.9642486, 685.7882486, 689.6122486, 693.4362486, 697.2602486, 701.0842486, 720.2042486, 724.0282486, 727.852248, 731.676248, 735.500248, 739.324248, 669.0255044, 673.9070044, 678.7885044, 683.6700044, 688.5515044, 693.4330044, 717.8405044, 722.7220044, 727.6035044, 732.4850044, 737.3665044, 742.2480044, 766.6555043, 771.5370043, 776.4185043, 781.3000043, 786.1815043, 791.0630043, 815.4705043, 820.3520043, 825.2335043, 830.1150043, 834.9965043, 839.8780043, 864.2855042, 869.1670042, 874.0485042, 878.9300042, 883.8115042, 888.6930042, 913.1005042, 917.9820042, 922.8635042, 927.7450042, 932.6265042, 937.5080042, 809.0467424, 814.9857424, 820.9247424, 826.8637423, 832.8027423, 838.7417423, 868.4367421, 874.3757421, 880.3147420, 886.2537420, 892.1927420, 898.13174, 927.8267418, 933.7657418, 939.7047417, 945.6437417, 951.5827417, 957.5217416, 987.2167415, 993.155741, 999.0947414, 1005.0337414, 1010.972741, 1016.9117413, 1046.6067412, 1052.5457411, 1058.4847411, 1064.4237411, 1070.3627410, 1076.3017410, 1105.996740, 1111.9357408, 1117.8747408, 1123.8137408, 1129.7527407, 1135.6917407, 949.0679815, 956.0644814, 963.060981, 970.0574813, 977.0539812, 984.0504811, 1019.0329807, 1026.0294807, 1033.0259806, 1040.0224805, 1047.0189804, 1054.0154804, 1088.9979800, 1095.9944799, 1102.9909798, 1109.987479, 1116.9839797, 1123.9804796, 1158.9629792, 1165.9594791, 1172.9559791, 1179.9524790, 1186.9489789, 1193.9454788, 1228.9279785, 1235.9244784, 1242.9209783, 1249.9174782, 1256.913978, 1263.9104781, 1298.8929777, 1305.8894776, 1312.8859775, 1319.8824775, 1326.8789774, 1333.8754773, 1089.0892560, 1097.1432561, 1105.1972562, 1113.251256, 1121.3052563, 1129.3592564, 1169.6292568, 1177.6832568, 1185.7372569, 1193.7912570, 1201.845257, 1209.8992571, 1250.1692575, 1258.2232576, 1266.2772576, 1274.3312577, 1282.3852578, 1290.4392579, 1330.7092582, 1338.7632583, 1346.8172584, 1354.8712584, 1362.9252585, 1370.9792586, 1411.24925, 1419.3032590, 1427.3572591, 1435.4112592, 1443.465259, 1451.5192593, 1491.7892597, 1499.8432598, 1507.8972598, 1515.9512599, 1524.0052600, 1532.059260, 1229.1105073, 1238.2220073, 1247.3335073, 1256.4450073, 1265.5565073, 1274.668007, 1320.2255074, 1329.3370074, 1338.4485074, 1347.5600075, 1356.6715075, 1365.7830075, 1411.340507, 1420.4520076, 1429.5635076, 1438.6750076, 1447.7865076, 1456.8980076, 1502.4555077, 1511.5670077, 1520.6785077, 1529.7900077, 1538.9015077, 1548.013007, 1593.5705078, 1602.6820078, 1611.793507, 1620.9050079, 1630.0165079, 1639.1280079, 1684.6855080, 1693.7970080, 1702.9085080, 1712.0200080, 1721.1315080, 1730.2430080, 1369.1317613, 1379.3007614, 1389.4697614, 1399.6387615, 1409.8077615, 1419.976761, 1470.8217618, 1480.9907618, 1491.159761, 1501.3287619, 1511.4977619, 1521.6667620, 1572.5117622, 1582.6807622, 1592.8497623, 1603.0187623, 1613.1877624, 1623.3567624, 1674.2017626, 1684.3707627, 1694.5397627, 1704.7087628, 1714.8777628, 1725.046762, 1775.8917631, 1786.0607631, 1796.229763, 1806.3987632, 1816.5677632, 1826.7367633, 1877.5817635, 1887.7507635, 1897.9197636, 1908.0887636, 1918.2577637, 1928.4267637, 304.3905022, 305.0420022, 305.6935022, 306.3450022, 306.9965022, 307.6480022, 310.9055022, 311.5570022, 312.208502, 312.860002, 313.5115023, 314.1630023, 317.4205023, 318.0720023, 318.7235023, 319.3750023, 320.0265023, 320.6780023, 323.9355023, 324.5870023, 325.2385023, 325.8900023, 326.541502, 327.193002, 330.4505024, 331.1020024, 331.7535024, 332.4050024, 333.0565024, 333.7080024, 336.9655024, 337.6170024, 338.2685024, 338.9200024, 339.5715024, 340.223002, 761.6617542, 763.3707542, 765.0797542, 766.7887542, 768.4977542, 770.206754, 778.7517543, 780.4607543, 782.1697543, 783.8787543, 785.5877543, 787.2967543, 795.8417544, 797.5507544, 799.2597544, 800.9687544, 802.6777544, 804.3867544, 812.9317545, 814.6407545, 816.3497545, 818.0587545, 819.7677545, 821.4767545, 830.0217546, 831.7307546, 833.4397546, 835.1487546, 836.8577546, 838.5667546, 847.1117547, 848.8207547, 850.5297547, 852.2387547, 853.9477547, 855.6567547, 1218.9329915, 1221.6994915, 1224.4659915, 1227.232491, 1229.9989914, 1232.7654914, 1246.5979913, 1249.3644913, 1252.1309913, 1254.8974913, 1257.6639913, 1260.430491, 1274.2629912, 1277.029491, 1279.7959911, 1282.5624911, 1285.3289911, 1288.0954911, 1301.9279910, 1304.6944910, 1307.4609910, 1310.22749, 1312.9939909, 1315.7604909, 1329.5929908, 1332.3594908, 1335.1259908, 1337.8924908, 1340.6589908, 1343.4254908, 1357.2579907, 1360.0244907, 1362.7909906, 1365.5574906, 1368.3239906, 1371.0904906, 1676.2042479, 1680.0282479, 1683.8522479, 1687.6762479, 1691.5002479, 1695.3242479, 1714.4442479, 1718.2682479, 1722.0922479, 1725.9162479, 1729.7402479, 1733.5642479, 1752.6842479, 1756.5082479, 1760.3322479, 1764.1562479, 1767.9802479, 1771.8042479, 1790.9242479, 1794.7482479, 1798.5722479, 1802.3962479, 1806.2202479, 1810.044247, 1829.1642478, 1832.9882478, 1836.8122478, 1840.6362478, 1844.4602478, 1848.2842478, 1867.4042478, 1871.2282478, 1875.0522478, 1878.8762478, 1882.7002478, 1886.5242478, 2133.4755029, 2138.3570029, 2143.2385029, 2148.1200029, 2153.0015029, 2157.8830029, 2182.2905028, 2187.1720028, 2192.0535028, 2196.9350028, 2201.8165028, 2206.6980028, 2231.1055028, 2235.9870028, 2240.8685028, 2245.7500028, 2250.6315028, 2255.5130028, 2279.9205027, 2284.8020027, 2289.6835027, 2294.5650027, 2299.4465027, 2304.3280027, 2328.7355027, 2333.6170027, 2338.4985027, 2343.3800027, 2348.2615027, 2353.1430027, 2377.5505026, 2382.4320026, 2387.3135026, 2392.1950026, 2397.0765026, 2401.9580026, 2590.7467330, 2596.6857330, 2602.6247329, 2608.5637329, 2614.5027329, 2620.441732, 2650.1367327, 2656.0757327, 2662.0147326, 2667.9537326, 2673.8927326, 2679.8317325, 2709.5267324, 2715.465732, 2721.4047323, 2727.3437323, 2733.282732, 2739.2217322, 2768.9167321, 2774.8557320, 2780.7947320, 2786.7337320, 2792.6727319, 2798.6117319, 2828.306731, 2834.2457317, 2840.1847317, 2846.1237317, 2852.0627316, 2858.0017316, 2887.6967314, 2893.6357314, 2899.5747314, 2905.5137313, 2911.4527313, 2917.3917313, 3048.0179587, 3055.0144586, 3062.0109585, 3069.0074584, 3076.0039584, 3083.0004583, 3117.9829579, 3124.9794578, 3131.9759578, 3138.9724577, 3145.9689576, 3152.9654575, 3187.947957, 3194.9444571, 3201.9409570, 3208.9374569, 3215.933956, 3222.9304568, 3257.9129564, 3264.9094563, 3271.9059562, 3278.9024562, 3285.8989561, 3292.8954560, 3327.8779556, 3334.874455, 3341.8709555, 3348.8674554, 3355.8639553, 3362.860455, 3397.8429549, 3404.8394548, 3411.8359547, 3418.8324546, 3425.8289546, 3432.8254545, 3505.28927, 3513.3432780, 3521.3972781, 3529.4512782, 3537.5052782, 3545.5592783, 3585.8292787, 3593.8832788, 3601.9372788, 3609.9912789, 3618.0452790, 3626.099279, 3666.3692794, 3674.4232795, 3682.4772796, 3690.5312796, 3698.5852797, 3706.6392798, 3746.9092801, 3754.9632802, 3763.0172803, 3771.0712804, 3779.1252804, 3787.1792805, 3827.4492809, 3835.50328, 3843.5572810, 3851.6112811, 3859.6652812, 3867.7192812, 3907.9892816, 3916.0432817, 3924.097281, 3932.1512818, 3940.2052819, 3948.2592820, 3962.5605113, 3971.6720113, 3980.783511, 3989.8950114, 3999.0065114, 4008.1180114, 4053.6755115, 4062.7870115, 4071.8985115, 4081.0100115, 4090.1215115, 4099.2330115, 4144.7905116, 4153.9020116, 4163.0135116, 4172.1250116, 4181.236511, 4190.3480117, 4235.9055117, 4245.0170117, 4254.128511, 4263.2400118, 4272.3515118, 4281.4630118, 4327.0205119, 4336.1320119, 4345.2435119, 4354.3550119, 4363.4665119, 4372.5780119, 4418.1355120, 4427.2470120, 4436.3585120, 4445.4700120, 4454.581512, 4463.6930121, 4419.8317743, 4430.0007744, 4440.1697744, 4450.338774, 4460.5077745, 4470.6767745, 4521.521774, 4531.6907748, 4541.8597748, 4552.0287749, 4562.1977749, 4572.3667750, 4623.2117752, 4633.3807752, 4643.5497753, 4653.7187753, 4663.8877754, 4674.0567754, 4724.9017756, 4735.0707757, 4745.2397757, 4755.4087757, 4765.5777758, 4775.7467758, 4826.591776, 4836.7607761, 4846.9297761, 4857.0987762, 4867.2677762, 4877.4367763, 4928.2817765, 4938.4507765, 4948.6197766, 4958.7887766, 4968.957776, 4979.12677675}; + TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, + 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, + 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, + 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, + 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, + 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, + 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, + 3932.1512818f, 3940.2052819f, 3948.2592820f, 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, + 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, + 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; NDArray expFF(_expBFF, _expSFF); @@ -625,11 +635,11 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { } TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { - TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.}; + TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.f}; Nd4jLong _expSFF[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; NDArray expFF(_expBFF, _expSFF); - TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062}; - Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; + TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062f}; + Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray exp2FF(_exp2BFF, _exp2SFF); auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 4cbf6b6dd..de3cdcdba 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -212,12 +212,12 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) { } TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { - TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139}; + TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray expGWP(_expGradWpB, _expGradWpS); expGWP.permutei({2,3,1,0}); - TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747}; + TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray expGWD(_expGradWdB, _expGradWdS); expGWD.permutei({2,3,1,0}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 7036ef77f..591746804 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests1, TestGemv1) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00,64.00,100.00,136.00,172.00}; + auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - // result->printBuffer(); + //expected.printIndexedBuffer("E"); + //result->printIndexedBuffer("R"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); @@ -3605,7 +3606,9 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) { auto input = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, + 15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, + 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); input.linspace(1); nd4j::ops::reverse op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 6375d935c..21c18299e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -121,10 +121,10 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) { } TEST_F(DeclarableOpsTests10, Test_Not_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); // auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); - auto e = NDArrayFactory::create('c', {4}, {0, 0, 1, 0}); + auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); nd4j::ops::boolean_not op; auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); @@ -245,7 +245,8 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { - auto cond2d = NDArrayFactory::create('c', {3, 5}, {1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1}); + auto cond2d = NDArrayFactory::create('c', {3, 5}, {true, true, false, false, true, true, true, + true, true, true, false, true, true, true, true}); // auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); @@ -623,7 +624,7 @@ TEST_F(DeclarableOpsTests10, range_test11) { ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test12) { - auto exp = NDArrayFactory::create('c', {9}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5}); + auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); nd4j::ops::range op; auto result = op.execute({}, {0.5, 5, 0.5}, {}, {}); @@ -1416,7 +1417,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, @@ -1470,6 +1471,138 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input, &size}, {}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + ASSERT_NE(*result, ex); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + ASSERT_NE(*result, ex); + + delete results; +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1., 2., 3., 4., + 2.6, 3.6, 4.6, 5.6, + 5., 6., 7., 8., + 7.4, 8.4, 9.4, 10.4, + 9., 10., 11., 12., + + 4., 5., 6., 7., + 5.6, 6.6, 7.6, 8.6, + 8., 9., 10., 11., + 10.4, 11.4, 12.4, 13.4, + 12., 13., 14., 15., + + 10., 11., 12., 13., + 11.6, 12.6, 13.6, 14.6, + 14., 15., 16., 17., + 16.4, 17.4, 18.4, 19.4, + 18., 19., 20., 21., + + 13., 14., 15., 16., + 14.6, 15.6, 16.6, 17.6, + 17., 18., 19., 20., + 19.4, 20.4, 21.4, 22.4, + 21., 22., 23., 24. + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); + //expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 2.6f, 3.6f, 4.6f, 5.6f, + 5.f, 6.f, 7.f, 8.f, + 7.4f, 8.4f, 9.4f, 10.4f, + 9.f, 10.f, 11.f, 12.f, + + 4.f, 5.f, 6.f, 7.f, + 5.6f, 6.6f, 7.6f, 8.6f, + 8.f, 9.f, 10.f, 11.f, + 10.4f, 11.4f, 12.4f, 13.4f, + 12.f, 13.f, 14.f, 15.f, + + 10.f, 11.f, 12.f, 13.f, + 11.6f, 12.6f, 13.6f, 14.6f, + 14.f, 15.f, 16.f, 17.f, + 16.4f, 17.4f, 18.4f, 19.4f, + 18.f, 19.f, 20.f, 21.f, + + 13.f, 14.f, 15.f, 16.f, + 14.6f, 15.6f, 16.6f, 17.6f, + 17.f, 18.f, 19.f, 20.f, + 19.4f, 20.4f, 21.4f, 22.4f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { NDArray input = NDArrayFactory::create('c', {2,3,4}); @@ -1857,7 +1990,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { input.linspace(1); nd4j::ops::resize_bilinear op; - auto results = op.execute({&input}, {}, {10, 10, 1}); + auto results = op.execute({&input}, {}, {10, 10}, {true}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1986,7 +2119,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { input.linspace(1); nd4j::ops::resize_bilinear op; - auto results = op.execute({&input, &size}, {}, {1}); + auto results = op.execute({&input, &size}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2023,7 +2156,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { 1, 2, 3, 4, + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, @@ -2051,7 +2185,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { input.linspace(1); nd4j::ops::resize_nearest_neighbor op; - auto results = op.execute({&input}, {}, {4, 5}); + auto results = op.execute({&input}, {}, {4, 5}, {false, false}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2070,7 +2204,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { 1, 2, 3, 4, + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, @@ -2112,6 +2247,54 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { delete results; } +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_nearest_neighbor op; + auto results = op.execute({&input}, {}, {4,5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printIndexedBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { NDArray input = NDArrayFactory::create('c', {2, 3, 4}); @@ -2533,7 +2716,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32); + NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32); nd4j::ops::crop_and_resize op; auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0}); @@ -2557,7 +2740,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32); + NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32); nd4j::ops::crop_and_resize op; auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1}); @@ -2726,7 +2909,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); - NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32); + NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32); NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); @@ -2971,22 +3154,6 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { delete results; } -/* public void testFakeQuantAgainstTF_1() { - INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); - INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5); - INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5); - - INDArray out = Nd4j.createUninitialized(x.shape()); - val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out); - - INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, - 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, - 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); - - assertEquals(expected, out); - }*/ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, @@ -3094,12 +3261,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { TEST_F(DeclarableOpsTests10, batchnorm_test1) { NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); @@ -3211,19 +3378,19 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { TEST_F(DeclarableOpsTests10, batchnorm_test5) { NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20., - -19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438, - -12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, + -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3240,14 +3407,14 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) { TEST_F(DeclarableOpsTests10, batchnorm_test6) { NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 , - 20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 , - -12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, + 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, + -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); nd4j::ops::batchnorm op; @@ -3270,7 +3437,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32); NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32); - NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, nd4j::DataType::BOOL); + NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, nd4j::DataType::BOOL); NDArray result('c', {2,2,2}, nd4j::DataType::BOOL); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 67ecf5576..5ca22c95e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) { auto input = NDArrayFactory::create('c', {4, 5}); auto idx = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create({0, 0, 0, 1}); + auto exp = NDArrayFactory::create({false, false, false, true}); int exclusive, reverse; input.linspace(1); @@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) { TEST_F(DeclarableOpsTests12, inTopK_5) { auto x = NDArrayFactory::create('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('f', {6}, {1, 0, 0, 0, 0, 0 }); + auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); nd4j::ops::in_top_k op; auto result = op.execute({&x, &y}, {}, {2}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 76a44be0b..91ff89d46 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1167,12 +1167,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139, - 0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106, - 0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, + 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, + 0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1230,12 +1230,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; @@ -1245,18 +1245,19 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289, - -0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647, - 0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395, - -0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862, - 0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308, - -0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128, - 0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, 2 * nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, + -0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, + 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f, + -0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, + 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f, + -0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, + 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, - -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, - -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f, + -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f, + -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1328,16 +1329,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406, - 0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 , - 0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739, - 0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681, - 0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + NDArray expH('c', {bS, sL, 2*nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, + 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, + 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, + 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, + 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492, - -0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609, - -0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, + -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f, + -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1398,12 +1400,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; @@ -1413,14 +1415,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060, - 0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491, - 0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, + 0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f, + 0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, - -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, - -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1568,12 +1573,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674, - 0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028, - 0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, + 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, + 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1650,16 +1656,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843, - 0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139, - 0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 , - 0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 , - 0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, 2*nOut}, { + 0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, + 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, + 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, + 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, + 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502, - -0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425, - -0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f, + -0.103843f, -0.103843f, -0.103843f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f, + -0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1731,14 +1738,20 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023, - 0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f, + 0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f, + 0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1799,25 +1812,26 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) { NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; + Wx = 0.003f; + Wr = 0.006f; + b = 0.5f; + hI = 1.f; + cI = 2.f; + Wp = -0.05f; std::initializer_list tArgs = {cellClip}; std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209, - 0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627, - 0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, - 0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f, + 0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f, + 0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, + 0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1878,18 +1892,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05; - Wp({1,2, 0,0}) = 0.05; + Wp({0,1, 0,0}) = -0.05f; + Wp({1,2, 0,0}) = 0.05f; std::initializer_list tArgs = {cellClip}; std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; @@ -1905,10 +1919,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315, - 0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702, - 0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f, + 0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f, + 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index b2ccad86f..488adad0c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -148,8 +148,8 @@ TEST_F(DeclarableOpsTests15, Test_standarize_1) { } TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto eps = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::standardize_bp op; auto result = op.execute({&x, &eps}, {}, {0}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 38d88b469..f8bf47e53 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -196,4 +196,45 @@ TEST_F(DeclarableOpsTests16, test_range_2) { ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); delete shapes; +} + +TEST_F(DeclarableOpsTests16, test_reverse_1) { + std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; + std::vector columns = {6, 5, 10, 100, 153, 171, 635}; + + for (auto r : rows) { + for (auto c : columns) { + //nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', {r, c}); + auto exp = NDArrayFactory::create('c', {r, c}); + auto reversed = NDArrayFactory::create('c', {r, c}); + + auto rowOriginal = NDArrayFactory::create('c', {c}); + auto rowReversed = NDArrayFactory::create('c', {c}); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float) e); + rowReversed.p(c - e - 1, (float) e); + } + + + auto listI = array.allTensorsAlongDimension({1}); + auto listE = exp.allTensorsAlongDimension({1}); + + for (int e = 0; e < r; e++) { + listI->at(e)->assign(rowOriginal); + listE->at(e)->assign(rowReversed); + } + + delete listI; + delete listE; + + nd4j::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); + } + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 9f9c39156..a8377b429 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -1591,7 +1591,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.); + ASSERT_TRUE(result->e(0) == -71.f); delete results; @@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -69.); + ASSERT_TRUE(result->e(0) == -69.f); delete results; @@ -1630,8 +1630,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { auto weights = NDArrayFactory::create('c', {2,3,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + weights.assign(0.5f); + predictions.assign(0.5f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); @@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.); + ASSERT_TRUE(result->e(0) == -24.f); delete results; @@ -1655,8 +1655,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { auto weights = NDArrayFactory::create('c', {1,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + weights.assign(0.5f); + predictions.assign(0.5f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); @@ -1680,10 +1680,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { auto weights = NDArrayFactory::create('c', {2,3,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); + weights.assign(0.5f); + predictions.assign(0.5f); + weights.p(0, 0.f); + weights.p(1, 0.f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 6d224b323..5322a0a6d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) { b.linspace(10.); x.assign(1.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.,1.,1.,1.,1.,1.,1.}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -2292,9 +2292,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2329,9 +2329,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2366,9 +2366,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2421,9 +2421,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 23351f7af..220191011 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -4084,7 +4084,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) { TEST_F(DeclarableOpsTests7, Softsign_BP_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f}); NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); nd4j::ops::softsign ffOP; nd4j::ops::softsign_bp bpOp; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu index 161b96918..f88cddde5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -24,6 +24,7 @@ #include #include #include +#include using namespace nd4j; @@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { //ASSERT_TRUE(exp.isSameShape(z)); delete result; +} -} \ No newline at end of file +/* +TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) { + auto x = NDArrayFactory::create('c', {1, 3, 608, 608}); + auto z = x.like(); + x.linspace(1.0f); + + nd4j::ops::reverse op; + auto timeStart = std::chrono::system_clock::now(); + auto status = op.execute({&x}, {&z}, {}, {1}, {}); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + nd4j_printf("exec time: %lld us\n", outerTime); + ASSERT_EQ(Status::OK(), status); +} +*/ \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index c89a989a9..e7f7f7e68 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -661,9 +661,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) { auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); // auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); - auto o = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); NDArray::prepareSpecialUse({&o}, {&x, &y}); @@ -685,9 +685,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) { TEST_F(JavaInteropTests, Test_Greater_2) { auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); - auto o = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); nd4j::ops::greater op; diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 71ad6929b..7740cd1ac 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -1163,10 +1163,10 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32); NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32); - NDArray exp1('c', {3}, {4., 20., 36.}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {-10., -2., 6.,14., 22., 30.}, nd4j::DataType::FLOAT32); - NDArray exp3('c', {4}, {38., 41., 44., 47.}, nd4j::DataType::FLOAT32); - NDArray exp4('c', {4}, {114., 117., 120., 123.}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, nd4j::DataType::FLOAT32); + NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, nd4j::DataType::FLOAT32); + NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32); NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); @@ -1271,8 +1271,10 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {3,2}, {-88., -124., 6., -2., 22., 14.}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {6,4}, {-36., -44., -52., -60.,-42., -52., -62., -72.,2., 0., -2., -4.,6., 4., 2., 0.,10., 8., 6., 4.,14., 12., 10., 8.}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, + -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, + nd4j::DataType::FLOAT32); NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); @@ -1400,10 +1402,10 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::FLOAT32); + NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32); x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1503,7 +1505,7 @@ TEST_F(NDArrayCudaBasicsTests, EqualityTest1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { - NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32); + NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32); NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); @@ -1511,11 +1513,11 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32); - NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32); - NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2}, {21.5,5}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32); + NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32); + NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); + NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32); x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1575,17 +1577,17 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE); - NDArray z1('c', {}, {100}, nd4j::DataType::BOOL); - NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL); - NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL); - NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL); + NDArray z1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL); + NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); + NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL); + NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL); - NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); - NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL); - NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL); - NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL); - NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL); + NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); + NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); + NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL); x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1643,7 +1645,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { - NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32); + NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32); NDArray z1('c', {}, {100}, nd4j::DataType::INT64); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64); @@ -1912,7 +1914,7 @@ TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) { double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - NDArray a('c', {4,4}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, nd4j::DataType::FLOAT32); + NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, nd4j::DataType::FLOAT32); auto x = NDArrayFactory::create('c', {3, 2, 1}); auto y = NDArrayFactory::create('c', {1, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); @@ -1928,7 +1930,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, assign_2) { - NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, nd4j::DataType::FLOAT32); + NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, nd4j::DataType::FLOAT32); NDArray y('c', {4}, nd4j::DataType::INT32); NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32); @@ -1945,30 +1947,30 @@ TEST_F(NDArrayCudaBasicsTests, subarray_1) NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32); Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX0[] = {1.000000, 13.000000}; + float buffExpX0[] = {1.f, 13.f}; Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX1[] = {2.000000, 14.000000}; + float buffExpX1[] = {2.f, 14.f}; Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; - float buffExpX2[] = {1.000000, 13.000000}; + float buffExpX2[] = {1.f, 13.f}; Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; - float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; - float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; - float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000}; + float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY0[] = {1.000000, 2.000000}; + float buffExpY0[] = {1.f, 2.f}; Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY1[] = {7.000000, 8.000000}; + float buffExpY1[] = {7.f, 8.f}; Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.000000, 2.000000}; + float buffExpY2[] = {1.f, 2.f}; Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; - float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; - float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; - float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000}; + float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; NDArray x0 = x(0, {1,2}); @@ -2121,7 +2123,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); //x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x->reshapei('c', {3, 4, 5}); x->permutei({0, 1, 2}); @@ -2138,7 +2140,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { auto x = NDArrayFactory::create('c', {1, 60}); x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); @@ -2153,7 +2155,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) { auto x = NDArrayFactory::create('c', {1, 60}); x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); @@ -2170,7 +2172,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) { auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); // auto x = *xx; //x.linspace(1); -// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); +// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); // x.reshapei('c', {3, 4, 5}); // x.permutei({0, 1, 2}); @@ -2188,7 +2190,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) { //x.linspace(1); for (int l = 0; l < x.lengthOf(); l++) x.p(l, float(l + 1.f)); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index 0f3cab509..d0fb4bf37 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -774,7 +774,7 @@ TEST_F(NDArrayTest, TestTile3) { TEST_F(NDArrayTest, TestTile4) { float xBuff[] = {1,2,3,4,5,6}; - float expBuff[] = {1.,2., 1.,2., 3.,4., 3.,4., 5.,6., 5.,6.}; + float expBuff[] = {1.f,2.f, 1.f,2.f, 3.f,4.f, 3.f,4.f, 5.f,6.f, 5.f,6.f}; auto x = NDArrayFactory::create(xBuff, 'c', {3,1,2}); auto exp = NDArrayFactory::create(expBuff, 'c', {3,2,2}); @@ -789,7 +789,7 @@ TEST_F(NDArrayTest, TestTile4) { TEST_F(NDArrayTest, TestTile5) { float xBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12}; - float expBuff[] = {1., 2., 3., 4., 1., 2., 3., 4., 5., 6., 7., 8., 5., 6., 7., 8., 9.,10., 11.,12., 9.,10., 11.,12.}; + float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f,12.f, 9.f,10.f, 11.f,12.f}; auto x = NDArrayFactory::create(xBuff, 'c', {3,2,2}); auto exp = NDArrayFactory::create(expBuff, 'c', {3,4,2}); @@ -847,7 +847,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul1) { auto y = NDArrayFactory::create('c', {3, 6}); Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0, 252.0, 273.0, 537.0, 594.0, 651.0, 843.0, 936.0, 1029.0}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -872,7 +872,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { auto y = NDArrayFactory::create('c', {3, 6}); Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0, 252.0, 273.0, 537.0, 594.0, 651.0, 843.0, 936.0, 1029.0}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -903,7 +903,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul3) { auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0, 1858.0, 2092.0, 2326.0, 5368.0, 5602.0, 5836.0, 6070.0, 4504.0, 5170.0, 5836.0, 6502.0, 15160.0, 15826.0, 16492.0, 17158.0}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -931,7 +931,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0, 1858.0, 2092.0, 2326.0, 5368.0, 5602.0, 5836.0, 6070.0, 4504.0, 5170.0, 5836.0, 6502.0, 15160.0, 15826.0, 16492.0, 17158.0}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -971,7 +971,7 @@ TEST_F(NDArrayTest, TestMmulHelper2) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00, 64.00, 100.00, 136.00, 172.00}; + auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo(), nd4j::LaunchContext ::defaultContext(), true); //nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -1000,7 +1000,7 @@ TEST_F(NDArrayTest, TestMmulHelper3) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{92.00, 104.00, 116.00, 128.00, 140.00}; + auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); //nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -1035,7 +1035,7 @@ TEST_F(NDArrayTest, TestMmulHelper4) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0, 21.0, 35.0, 10.0, 28.0, 46.0, 13.0, 35.0, 57.0}; + auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, 46.0f, 13.0f, 35.0f, 57.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1065,7 +1065,7 @@ TEST_F(NDArrayTest, TestMmulHelper5) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0, 14.0, 21.0, 12.0, 21.0, 30.0, 17.0, 28.0, 39.0}; + auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, 30.0f, 17.0f, 28.0f, 39.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1095,7 +1095,7 @@ TEST_F(NDArrayTest, TestMmulHelper6) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{39.0, 54.0, 69.0, 9.0, 18.0, 27.0, 9.0, 12.0, 15.0}; + auto expBuffer = new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1126,7 +1126,7 @@ TEST_F(NDArrayTest, TestMmulHelper7) { auto z = NDArrayFactory::create_('f', {1, 3}); - auto expBuffer = new float[9]{110.00, 260.00, 410.00}; + auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(y, x, z); @@ -1171,7 +1171,59 @@ TEST_F(NDArrayTest, TestMmulHelper_ND_1) { TEST_F(NDArrayTest, TestMmulHelper_ND_2) { Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; - float _expB[] = {1.07250000e+04, 1.10500000e+04, 2.63500000e+04, 2.73000000e+04, 4.19750000e+04, 4.35500000e+04, 5.76000000e+04, 5.98000000e+04, 7.32250000e+04, 7.60500000e+04, 8.88500000e+04, 9.23000000e+04, 1.04475000e+05, 1.08550000e+05, 1.20100000e+05, 1.24800000e+05, 1.35725000e+05, 1.41050000e+05, 1.51350000e+05, 1.57300000e+05, 1.66975000e+05, 1.73550000e+05, 1.82600000e+05, 1.89800000e+05, 1.98225000e+05, 2.06050000e+05, 2.13850000e+05, 2.22300000e+05, 2.29475000e+05, 2.38550000e+05, 2.45100000e+05, 2.54800000e+05, 2.60725000e+05, 2.71050000e+05, 2.76350000e+05, 2.87300000e+05, 2.91975000e+05, 3.03550000e+05, 3.07600000e+05, 3.19800000e+05, 3.23225000e+05, 3.36050000e+05, 3.38850000e+05, 3.52300000e+05, 3.54475000e+05, 3.68550000e+05, 3.70100000e+05, 3.84800000e+05, 3.85725000e+05, 4.01050000e+05, 4.01350000e+05, 4.17300000e+05, 4.16975000e+05, 4.33550000e+05, 4.32600000e+05, 4.49800000e+05, 4.48225000e+05, 4.66050000e+05, 4.63850000e+05, 4.82300000e+05, 4.79475000e+05, 4.98550000e+05, 4.95100000e+05, 5.14800000e+05, 5.10725000e+05, 5.31050000e+05, 5.26350000e+05, 5.47300000e+05, 5.41975000e+05, 5.63550000e+05, 5.57600000e+05, 5.79800000e+05, 5.73225000e+05, 5.96050000e+05, 5.88850000e+05, 6.12300000e+05, 6.04475000e+05, 6.28550000e+05, 6.20100000e+05, 6.44800000e+05, 6.35725000e+05, 6.61050000e+05, 6.51350000e+05, 6.77300000e+05, 6.66975000e+05, 6.93550000e+05, 6.82600000e+05, 7.09800000e+05, 6.98225000e+05, 7.26050000e+05, 7.13850000e+05, 7.42300000e+05, 7.29475000e+05, 7.58550000e+05, 7.45100000e+05, 7.74800000e+05, 7.60725000e+05, 7.91050000e+05, 7.76350000e+05, 8.07300000e+05, 7.91975000e+05, 8.23550000e+05, 8.07600000e+05, 8.39800000e+05, 8.23225000e+05, 8.56050000e+05, 8.38850000e+05, 8.72300000e+05, 8.54475000e+05, 8.88550000e+05, 8.70100000e+05, 9.04800000e+05, 8.85725000e+05, 9.21050000e+05, 9.01350000e+05, 9.37300000e+05, 9.16975000e+05, 9.53550000e+05, 9.32600000e+05, 9.69800000e+05, 9.48225000e+05, 9.86050000e+05, 9.63850000e+05, 1.00230000e+06, 9.79475000e+05, 1.01855000e+06, 9.95100000e+05, 1.03480000e+06, 1.01072500e+06, 1.05105000e+06, 1.02635000e+06, 1.06730000e+06, 1.04197500e+06, 1.08355000e+06, 1.05760000e+06, 1.09980000e+06, 1.07322500e+06, 1.11605000e+06, 1.08885000e+06, 1.13230000e+06, 1.10447500e+06, 1.14855000e+06, 1.12010000e+06, 1.16480000e+06, 1.13572500e+06, 1.18105000e+06, 1.15135000e+06, 1.19730000e+06, 1.16697500e+06, 1.21355000e+06, 3.54260000e+06, 3.58980000e+06, 3.58947500e+06, 3.63730000e+06, 3.63635000e+06, 3.68480000e+06, 3.68322500e+06, 3.73230000e+06, 3.73010000e+06, 3.77980000e+06, 3.77697500e+06, 3.82730000e+06, 3.82385000e+06, 3.87480000e+06, 3.87072500e+06, 3.92230000e+06, 3.91760000e+06, 3.96980000e+06, 3.96447500e+06, 4.01730000e+06, 4.01135000e+06, 4.06480000e+06, 4.05822500e+06, 4.11230000e+06, 4.10510000e+06, 4.15980000e+06, 4.15197500e+06, 4.20730000e+06, 4.19885000e+06, 4.25480000e+06, 4.24572500e+06, 4.30230000e+06, 4.29260000e+06, 4.34980000e+06, 4.33947500e+06, 4.39730000e+06, 4.38635000e+06, 4.44480000e+06, 4.43322500e+06, 4.49230000e+06, 4.48010000e+06, 4.53980000e+06, 4.52697500e+06, 4.58730000e+06, 4.57385000e+06, 4.63480000e+06, 4.62072500e+06, 4.68230000e+06, 4.66760000e+06, 4.72980000e+06, 4.71447500e+06, 4.77730000e+06, 4.76135000e+06, 4.82480000e+06, 4.80822500e+06, 4.87230000e+06, 4.85510000e+06, 4.91980000e+06, 4.90197500e+06, 4.96730000e+06, 4.94885000e+06, 5.01480000e+06, 4.99572500e+06, 5.06230000e+06, 5.04260000e+06, 5.10980000e+06, 5.08947500e+06, 5.15730000e+06, 5.13635000e+06, 5.20480000e+06, 5.18322500e+06, 5.25230000e+06, 5.23010000e+06, 5.29980000e+06, 5.27697500e+06, 5.34730000e+06, 5.32385000e+06, 5.39480000e+06, 5.37072500e+06, 5.44230000e+06, 5.41760000e+06, 5.48980000e+06, 5.46447500e+06, 5.53730000e+06, 5.51135000e+06, 5.58480000e+06, 5.55822500e+06, 5.63230000e+06, 5.60510000e+06, 5.67980000e+06, 5.65197500e+06, 5.72730000e+06, 5.69885000e+06, 5.77480000e+06, 5.74572500e+06, 5.82230000e+06, 5.79260000e+06, 5.86980000e+06, 5.83947500e+06, 5.91730000e+06, 5.88635000e+06, 5.96480000e+06, 5.93322500e+06, 6.01230000e+06, 5.98010000e+06, 6.05980000e+06, 6.02697500e+06, 6.10730000e+06, 6.07385000e+06, 6.15480000e+06, 6.12072500e+06, 6.20230000e+06, 6.16760000e+06, 6.24980000e+06, 6.21447500e+06, 6.29730000e+06, 6.26135000e+06, 6.34480000e+06, 6.30822500e+06, 6.39230000e+06, 6.35510000e+06, 6.43980000e+06, 6.40197500e+06, 6.48730000e+06, 6.44885000e+06, 6.53480000e+06, 6.49572500e+06, 6.58230000e+06, 6.54260000e+06, 6.62980000e+06, 6.58947500e+06, 6.67730000e+06, 6.63635000e+06, 6.72480000e+06, 6.68322500e+06, 6.77230000e+06, 6.73010000e+06, 6.81980000e+06, 6.77697500e+06, 6.86730000e+06, 6.82385000e+06, 6.91480000e+06, 6.87072500e+06, 6.96230000e+06, 6.91760000e+06, 7.00980000e+06, 6.96447500e+06, 7.05730000e+06, 7.01135000e+06, 7.10480000e+06, 1.17619750e+07, 1.18560500e+07, 1.18401000e+07, 1.19348000e+07, 1.19182250e+07, 1.20135500e+07, 1.19963500e+07, 1.20923000e+07, 1.20744750e+07, 1.21710500e+07, 1.21526000e+07, 1.22498000e+07, 1.22307250e+07, 1.23285500e+07, 1.23088500e+07, 1.24073000e+07, 1.23869750e+07, 1.24860500e+07, 1.24651000e+07, 1.25648000e+07, 1.25432250e+07, 1.26435500e+07, 1.26213500e+07, 1.27223000e+07, 1.26994750e+07, 1.28010500e+07, 1.27776000e+07, 1.28798000e+07, 1.28557250e+07, 1.29585500e+07, 1.29338500e+07, 1.30373000e+07, 1.30119750e+07, 1.31160500e+07, 1.30901000e+07, 1.31948000e+07, 1.31682250e+07, 1.32735500e+07, 1.32463500e+07, 1.33523000e+07, 1.33244750e+07, 1.34310500e+07, 1.34026000e+07, 1.35098000e+07, 1.34807250e+07, 1.35885500e+07, 1.35588500e+07, 1.36673000e+07, 1.36369750e+07, 1.37460500e+07, 1.37151000e+07, 1.38248000e+07, 1.37932250e+07, 1.39035500e+07, 1.38713500e+07, 1.39823000e+07, 1.39494750e+07, 1.40610500e+07, 1.40276000e+07, 1.41398000e+07, 1.41057250e+07, 1.42185500e+07, 1.41838500e+07, 1.42973000e+07, 1.42619750e+07, 1.43760500e+07, 1.43401000e+07, 1.44548000e+07, 1.44182250e+07, 1.45335500e+07, 1.44963500e+07, 1.46123000e+07, 1.45744750e+07, 1.46910500e+07, 1.46526000e+07, 1.47698000e+07, 1.47307250e+07, 1.48485500e+07, 1.48088500e+07, 1.49273000e+07, 1.48869750e+07, 1.50060500e+07, 1.49651000e+07, 1.50848000e+07, 1.50432250e+07, 1.51635500e+07, 1.51213500e+07, 1.52423000e+07, 1.51994750e+07, 1.53210500e+07, 1.52776000e+07, 1.53998000e+07, 1.53557250e+07, 1.54785500e+07, 1.54338500e+07, 1.55573000e+07, 1.55119750e+07, 1.56360500e+07, 1.55901000e+07, 1.57148000e+07, 1.56682250e+07, 1.57935500e+07, 1.57463500e+07, 1.58723000e+07, 1.58244750e+07, 1.59510500e+07, 1.59026000e+07, 1.60298000e+07, 1.59807250e+07, 1.61085500e+07, 1.60588500e+07, 1.61873000e+07, 1.61369750e+07, 1.62660500e+07, 1.62151000e+07, 1.63448000e+07, 1.62932250e+07, 1.64235500e+07, 1.63713500e+07, 1.65023000e+07, 1.64494750e+07, 1.65810500e+07, 1.65276000e+07, 1.66598000e+07, 1.66057250e+07, 1.67385500e+07, 1.66838500e+07, 1.68173000e+07, 1.67619750e+07, 1.68960500e+07, 1.68401000e+07, 1.69748000e+07, 1.69182250e+07, 1.70535500e+07, 1.69963500e+07, 1.71323000e+07, 1.70744750e+07, 1.72110500e+07, 1.71526000e+07, 1.72898000e+07, 1.72307250e+07, 1.73685500e+07, 1.73088500e+07, 1.74473000e+07, 1.73869750e+07, 1.75260500e+07, 1.74651000e+07, 1.76048000e+07, 1.75432250e+07, 1.76835500e+07, 2.46688500e+07, 2.48098000e+07, 2.47782250e+07, 2.49198000e+07, 2.48876000e+07, 2.50298000e+07, 2.49969750e+07, 2.51398000e+07, 2.51063500e+07, 2.52498000e+07, 2.52157250e+07, 2.53598000e+07, 2.53251000e+07, 2.54698000e+07, 2.54344750e+07, 2.55798000e+07, 2.55438500e+07, 2.56898000e+07, 2.56532250e+07, 2.57998000e+07, 2.57626000e+07, 2.59098000e+07, 2.58719750e+07, 2.60198000e+07, 2.59813500e+07, 2.61298000e+07, 2.60907250e+07, 2.62398000e+07, 2.62001000e+07, 2.63498000e+07, 2.63094750e+07, 2.64598000e+07, 2.64188500e+07, 2.65698000e+07, 2.65282250e+07, 2.66798000e+07, 2.66376000e+07, 2.67898000e+07, 2.67469750e+07, 2.68998000e+07, 2.68563500e+07, 2.70098000e+07, 2.69657250e+07, 2.71198000e+07, 2.70751000e+07, 2.72298000e+07, 2.71844750e+07, 2.73398000e+07, 2.72938500e+07, 2.74498000e+07, 2.74032250e+07, 2.75598000e+07, 2.75126000e+07, 2.76698000e+07, 2.76219750e+07, 2.77798000e+07, 2.77313500e+07, 2.78898000e+07, 2.78407250e+07, 2.79998000e+07, 2.79501000e+07, 2.81098000e+07, 2.80594750e+07, 2.82198000e+07, 2.81688500e+07, 2.83298000e+07, 2.82782250e+07, 2.84398000e+07, 2.83876000e+07, 2.85498000e+07, 2.84969750e+07, 2.86598000e+07, 2.86063500e+07, 2.87698000e+07, 2.87157250e+07, 2.88798000e+07, 2.88251000e+07, 2.89898000e+07, 2.89344750e+07, 2.90998000e+07, 2.90438500e+07, 2.92098000e+07, 2.91532250e+07, 2.93198000e+07, 2.92626000e+07, 2.94298000e+07, 2.93719750e+07, 2.95398000e+07, 2.94813500e+07, 2.96498000e+07, 2.95907250e+07, 2.97598000e+07, 2.97001000e+07, 2.98698000e+07, 2.98094750e+07, 2.99798000e+07, 2.99188500e+07, 3.00898000e+07, 3.00282250e+07, 3.01998000e+07, 3.01376000e+07, 3.03098000e+07, 3.02469750e+07, 3.04198000e+07, 3.03563500e+07, 3.05298000e+07, 3.04657250e+07, 3.06398000e+07, 3.05751000e+07, 3.07498000e+07, 3.06844750e+07, 3.08598000e+07, 3.07938500e+07, 3.09698000e+07, 3.09032250e+07, 3.10798000e+07, 3.10126000e+07, 3.11898000e+07, 3.11219750e+07, 3.12998000e+07, 3.12313500e+07, 3.14098000e+07, 3.13407250e+07, 3.15198000e+07, 3.14501000e+07, 3.16298000e+07, 3.15594750e+07, 3.17398000e+07, 3.16688500e+07, 3.18498000e+07, 3.17782250e+07, 3.19598000e+07, 3.18876000e+07, 3.20698000e+07, 3.19969750e+07, 3.21798000e+07, 3.21063500e+07, 3.22898000e+07, 3.22157250e+07, 3.23998000e+07, 3.23251000e+07, 3.25098000e+07, 3.24344750e+07, 3.26198000e+07, 3.25438500e+07, 3.27298000e+07, 3.26532250e+07, 3.28398000e+07, 3.27626000e+07, 3.29498000e+07}; + float _expB[] = { + 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, 4.19750000e+04f, 4.35500000e+04f, + 5.76000000e+04f, 5.98000000e+04f, 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, + 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, 1.35725000e+05f, 1.41050000e+05f, + 1.51350000e+05f, 1.57300000e+05f, 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, + 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, 2.29475000e+05f, 2.38550000e+05f, + 2.45100000e+05f, 2.54800000e+05f, 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, + 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, 3.23225000e+05f, 3.36050000e+05f, + 3.38850000e+05f, 3.52300000e+05f, 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, + 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, 4.16975000e+05f, 4.33550000e+05f, + 4.32600000e+05f, 4.49800000e+05f, 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, + 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, 5.10725000e+05f, 5.31050000e+05f, + 5.26350000e+05f, 5.47300000e+05f, 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, + 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, 6.04475000e+05f, 6.28550000e+05f, + 6.20100000e+05f, 6.44800000e+05f, 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, + 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, 6.98225000e+05f, 7.26050000e+05f, + 7.13850000e+05f, 7.42300000e+05f, 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, + 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, 7.91975000e+05f, 8.23550000e+05f, + 8.07600000e+05f, 8.39800000e+05f, 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, + 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, 8.85725000e+05f, 9.21050000e+05f, + 9.01350000e+05f, 9.37300000e+05f, 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, + 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, 9.79475000e+05f, 1.01855000e+06f, + 9.95100000e+05f, 1.03480000e+06f, 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, + 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, 1.07322500e+06f, 1.11605000e+06f, + 1.08885000e+06f, 1.13230000e+06f, 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, + 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, 1.16697500e+06f, 1.21355000e+06f, + 3.54260000e+06f, 3.58980000e+06f, 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, + 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, 3.77697500e+06f, 3.82730000e+06f, + 3.82385000e+06f, 3.87480000e+06f, 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, + 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, 4.05822500e+06f, 4.11230000e+06f, + 4.10510000e+06f, 4.15980000e+06f, 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, + 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, 4.33947500e+06f, 4.39730000e+06f, + 4.38635000e+06f, 4.44480000e+06f, 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, + 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, 4.62072500e+06f, 4.68230000e+06f, + 4.66760000e+06f, 4.72980000e+06f, 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, + 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, 4.90197500e+06f, 4.96730000e+06f, + 4.94885000e+06f, 5.01480000e+06f, 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, + 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, 5.18322500e+06f, 5.25230000e+06f, + 5.23010000e+06f, 5.29980000e+06f, 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, + 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, 5.46447500e+06f, 5.53730000e+06f, + 5.51135000e+06f, 5.58480000e+06f, 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, + 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, 5.74572500e+06f, 5.82230000e+06f, + 5.79260000e+06f, 5.86980000e+06f, 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, + 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, 6.02697500e+06f, 6.10730000e+06f, + 6.07385000e+06f, 6.15480000e+06f, 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, + 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, 6.30822500e+06f, 6.39230000e+06f, + 6.35510000e+06f, 6.43980000e+06f, 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, + 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, 6.58947500e+06f, 6.67730000e+06f, + 6.63635000e+06f, 6.72480000e+06f, 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, + 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, 6.87072500e+06f, 6.96230000e+06f, + 6.91760000e+06f, 7.00980000e+06f, 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, + 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, 1.19182250e+07f, 1.20135500e+07f, + 1.19963500e+07f, 1.20923000e+07f, 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; auto a = NDArrayFactory::create('c', {2, 72, 25}); for (int e = 0; e < a.lengthOf(); e++) @@ -1626,7 +1678,7 @@ TEST_F(NDArrayTest, applyReduce3Dot) { TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214, 1.414214, 5.385165, 5.385165}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; @@ -1649,7 +1701,7 @@ TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214, 1.414214, 5.385165, 5.385165}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; @@ -1670,7 +1722,7 @@ TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { TEST_F(NDArrayTest, TestVarianceAlongDimension1) { float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.816497, 0.816497}; + float expBuff[] = {0.816497f, 0.816497f}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; @@ -1688,7 +1740,7 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension1) { ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension2) { float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.666667, 0.666667}; + float expBuff[] = {0.666667f, 0.666667f}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5b60ac0b4..3c6b969b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -103,6 +103,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index 75b82dc29..f8763c41a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp { @Override public String[] tensorflowNames() { - return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; + return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java index be6eb3730..b6a96699c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.image; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.NoArgsConstructor; +import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -43,20 +44,25 @@ import java.util.Map; @NoArgsConstructor public class ResizeBilinear extends DynamicCustomOp { protected boolean alignCorners = false; + protected boolean halfPixelCenters = false; protected Integer height = null; protected Integer width = null; - public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){ + public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, + boolean alignCorners, boolean halfPixelCenters){ super(sd, input); this.alignCorners = alignCorners; this.height = height; this.width = width; + this.halfPixelCenters = halfPixelCenters; addArgs(); } - public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){ + public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, + boolean alignCorners, boolean halfPixelCenters) { super(new INDArray[]{x}, new INDArray[]{z}); this.alignCorners = alignCorners; + this.halfPixelCenters = halfPixelCenters; this.height = height; this.width = width; addArgs(); @@ -76,7 +82,12 @@ public class ResizeBilinear extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - this.alignCorners = attributesForNode.get("align_corners").getB(); + val attrC = attributesForNode.get("align_corners"); + val attrH = attributesForNode.get("half_pixel_centers"); + + this.alignCorners = attrC != null ? attrC.getB() : false; + this.halfPixelCenters = attrH != null ? attrH.getB() : false; + addArgs(); } @@ -87,8 +98,7 @@ public class ResizeBilinear extends DynamicCustomOp { iArguments.add(Long.valueOf(height)); iArguments.add(Long.valueOf(width)); } - iArguments.add(alignCorners ? 1L : 0L); - + addBArgument(alignCorners, halfPixelCenters); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 58602d85e..b966d4389 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { if(attributesForNode.containsKey("argmax")) { outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType()); } else { - outputType = DataType.UINT32; + outputType = DataType.LONG; } } @@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); List result = new ArrayList<>(); result.add(inputDataTypes.get(0)); - result.add(outputType == null ? DataType.UINT32 : outputType); + result.add(outputType == null ? DataType.INT : outputType); return result; } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 45a20bfbc..94c5601c1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -4584,6 +4584,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * returns reference on array element with given index */ + /** * returns array element with given index * i - element index in array @@ -5171,6 +5172,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + //////////////////////////////////////////////////////////////////////// @@ -5179,6 +5182,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + // #ifndef __JAVACPP_HACK__ // #endif diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e9a36d49f..0ba5d1293 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -4587,6 +4587,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * returns reference on array element with given index */ + /** * returns array element with given index * i - element index in array @@ -5174,6 +5175,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + //////////////////////////////////////////////////////////////////////// @@ -5182,6 +5185,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + // #ifndef __JAVACPP_HACK__ // #endif @@ -18280,7 +18285,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * Currently the case n = 0 is not supported. * * Input arrays: * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) @@ -18309,6 +18313,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +// #if NOT_EXCLUDED(OP_digamma) + @Namespace("nd4j::ops") public static class digamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public digamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public digamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public digamma position(long position) { + return (digamma)super.position(position); + } + + public digamma() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. * Input arrays: @@ -18398,9 +18430,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image hue by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing delta * * T arguments: - * 0 - delta value + * 0 - optional argument, delta value * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels @@ -18427,9 +18460,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image saturation by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation factor * * T arguments: - * 0 - saturation factor + * 0 - optional argument, saturation factor * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels @@ -18456,9 +18490,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) * Input arrays: * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation contrast factor * * T arguments: - * 0 - contrast factor + * 0 - optional argument, contrast factor * */ // #if NOT_EXCLUDED(OP_adjust_contrast) @@ -21053,7 +21088,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * compare_and_bitpack - compare with greater and pack result with uint8 + * compare_and_bitpack - compare with greater and pack result with uint8 * * input params: * 0 - NDArray (input) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9dd529399..6a32d9ea9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); + SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build()); + SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index 996ccff7f..a6f7b6bea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -23,13 +23,7 @@ import static org.junit.Assert.fail; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; public class ConvConfigTests { @@ -489,24 +483,24 @@ public class ConvConfigTests { @Test public void testConv1D(){ - Conv1DConfig.builder().k(2).build(); + Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); try{ - Conv1DConfig.builder().k(0).build(); + Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Kernel")); } try{ - Conv1DConfig.builder().k(4).s(-2).build(); + Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Stride")); } try{ - Conv1DConfig.builder().k(3).p(-2).build(); + Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Padding")); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index ec65d71df..bc9f03e2f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -117,9 +117,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 "fake_quant/min_max_args_per_channel.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 - "resize_bilinear/int32.*", - // Suggesting TF 1.15 bug "non_max_suppression_v2/float16.*", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index fbb1ddb85..742ffae66 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -972,7 +972,7 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray x = Nd4j.rand(1, 2,3,4); INDArray z = Nd4j.createUninitialized(x.shape()); boolean align = false; - val op = new ResizeBilinear(x, z, 10, 10, align); + val op = new ResizeBilinear(x, z, 10, 10, align, false); Nd4j.exec(op); } @@ -1174,6 +1174,7 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, x); } + @Ignore("AS failed 2019/12/04") @Test public void testPolygamma() { INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);