From 63ed202057d301755184dbb93fc777d4e590748c Mon Sep 17 00:00:00 2001 From: Susan Eraly Date: Wed, 4 Dec 2019 18:24:37 -0800 Subject: [PATCH 01/18] cleaned up bert iterator tests (#110) Signed-off-by: eraly --- .../iterator/TestBertIterator.java | 681 +++++++++--------- 1 file changed, 338 insertions(+), 343 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index a6716ba40..52644c360 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -17,11 +17,13 @@ package org.deeplearning4j.iterator; +import lombok.Getter; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.iterator.bert.BertMaskedLMMasker; +import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider; +import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider; import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory; import org.junit.Test; -import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -42,8 +44,12 @@ import static org.junit.Assert.*; public class TestBertIterator extends BaseDL4JTest { - private File pathToVocab = Resources.asFile("other/vocab.txt"); + private static File pathToVocab = Resources.asFile("other/vocab.txt"); private static Charset c = StandardCharsets.UTF_8; + private static String shortSentence = "I saw a girl with a telescope."; + private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + private static String sentenceA = "Goodnight noises everywhere"; + private static String sentenceB = "Goodnight moon"; public TestBertIterator() throws IOException { } @@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testBertSequenceClassification() throws Exception { - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); @@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest { System.out.println(mds.getFeatures(0)); System.out.println(mds.getFeaturesMaskArray(0)); - - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + List tokens = testHelper.getTokenizedSentences().get(i); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF, expFTemp); + expM = Nd4j.vstack(expM, expMTemp); } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); } - - INDArray expF = Nd4j.vstack(expEx0, expEx1); - INDArray expM = Nd4j.vstack(expM0, expM1); - assertEquals(expF, mds.getFeatures(0)); assertEquals(expM, mds.getFeaturesMaskArray(0)); - assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); - b.next(); //pop the third element assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); - forInference.set(0, toTokenize2); //Same thing, but with segment ID also b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); mds = b.next(); assertEquals(2, mds.getFeatures().length); - //assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null... - assertEquals(2, b.featurizeSentences(forInference).getFirst().length); //Segment ID should be all 0s for single segment task INDArray segmentId = expM.like(); assertEquals(segmentId, mds.getFeatures(1)); - assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]); + assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]); } @Test(timeout = 20000L) public void testBertUnsupervised() throws Exception { + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); //Task 1: Unsupervised - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.UNSUPERVISED) .masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5)) .unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX) .maskToken("[MASK]") .build(); - System.out.println("Mask token index: " + t.getVocab().get("[MASK]")); + System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]")); MultiDataSet mds = b.next(); System.out.println(mds.getFeatures(0)); @@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest { System.out.println(mds.getLabels(0)); System.out.println(mds.getLabelsMaskArray(0)); - b.next(); //pop the third element assertFalse(b.hasNext()); b.reset(); assertTrue(b.hasNext()); @@ -164,40 +159,34 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testLengthHandling() throws Exception { - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - System.out.println(tokens); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - System.out.println(tokens2); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); + int minibatchSize = 2; + TestSentenceHelper testHelper = new TestSentenceHelper(); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + List tokens = testHelper.getTokenizedSentences().get(i); + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF, expFTemp); + expM = Nd4j.vstack(expM, expMTemp); } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); } - INDArray expF = Nd4j.vstack(expEx0, expEx1); - INDArray expM = Nd4j.vstack(expM0, expM1); - //-------------------------------------------------------------- //Fixed length: clip or pad - already tested in other tests @@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest { //Any length: as long as we need to fit longest sequence BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); MultiDataSet mds = b.next(); @@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest { assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0)); assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0)); - assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); //Clip only: clip to maximum, but don't pad if less b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20) - .minibatchSize(2) - .sentenceProvider(new TestSentenceProvider()) + .minibatchSize(minibatchSize) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); - mds = b.next(); expShape = new long[]{2, 14}; assertArrayEquals(expShape, mds.getFeatures(0).shape()); assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape()); @@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest { @Test(timeout = 20000L) public void testMinibatchPadding() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - String toTokenize1 = "I saw a girl with a telescope."; - String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - String toTokenize3 = "Goodnight noises everywhere"; - List forInference = new ArrayList<>(); - forInference.add(toTokenize1); - forInference.add(toTokenize2); - forInference.add(toTokenize3); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM0 = Nd4j.create(DataType.INT, 1, 16); - List tokens = t.create(toTokenize1).getTokens(); - Map m = t.getVocab(); - for (int i = 0; i < tokens.size(); i++) { - int idx = m.get(tokens.get(i)); - expEx0.putScalar(0, i, idx); - expM0.putScalar(0, i, 1); - } - - INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM1 = Nd4j.create(DataType.INT, 1, 16); - List tokens2 = t.create(toTokenize2).getTokens(); - for (int i = 0; i < tokens2.size(); i++) { - String token = tokens2.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); - } - int idx = m.get(token); - expEx1.putScalar(0, i, idx); - expM1.putScalar(0, i, 1); - } - - INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16); - INDArray expM3 = Nd4j.create(DataType.INT, 1, 16); - List tokens3 = t.create(toTokenize3).getTokens(); - for (int i = 0; i < tokens3.size(); i++) { - String token = tokens3.get(i); - if (!m.containsKey(token)) { - throw new IllegalStateException("Unknown token: \"" + token + "\""); - } - int idx = m.get(token); - expEx3.putScalar(0, i, idx); - expM3.putScalar(0, i, 1); - } - + int minibatchSize = 3; + TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize); INDArray zeros = Nd4j.create(DataType.INT, 1, 16); - INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros); - INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros); - INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}}); + INDArray expF = Nd4j.create(DataType.INT, 1, 16); + INDArray expM = Nd4j.create(DataType.INT, 1, 16); + Map m = testHelper.getTokenizer().getVocab(); + for (int i = 0; i < minibatchSize; i++) { + List tokens = testHelper.getTokenizedSentences().get(i); + INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16); + INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16); + System.out.println(tokens); + for (int j = 0; j < tokens.size(); j++) { + String token = tokens.get(j); + if (!m.containsKey(token)) { + throw new IllegalStateException("Unknown token: \"" + token + "\""); + } + int idx = m.get(token); + expFTemp.putScalar(0, j, idx); + expMTemp.putScalar(0, j, 1); + } + if (i == 0) { + expF = expFTemp.dup(); + expM = expMTemp.dup(); + } else { + expF = Nd4j.vstack(expF.dup(), expFTemp); + expM = Nd4j.vstack(expM.dup(), expMTemp); + } + } + + expF = Nd4j.vstack(expF, zeros); + expM = Nd4j.vstack(expM, zeros); + INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}}); INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1); expLM.putScalar(0, 0, 1); expLM.putScalar(1, 0, 1); @@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest { //-------------------------------------------------------------- BertIterator b = BertIterator.builder() - .tokenizer(t) + .tokenizer(testHelper.getTokenizer()) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16) - .minibatchSize(4) + .minibatchSize(minibatchSize + 1) .padMinibatches(true) - .sentenceProvider(new TestSentenceProvider()) + .sentenceProvider(testHelper.getSentenceProvider()) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .build(); @@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest { assertEquals(expL, mds.getLabels(0)); assertEquals(expLM, mds.getLabelsMaskArray(0)); - assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]); - assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]); + assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]); + assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]); } + /* + Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair + Checks different lengths for max length to check popping and padding + */ @Test public void testSentencePairsSingle() throws IOException { - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; boolean prependAppend; - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - int shortL = t.create(shortSent).countTokens(); - int longL = t.create(longSent).countTokens(); + int numOfSentences; + + TestSentenceHelper testHelper = new TestSentenceHelper(); + int shortL = testHelper.getShortestL(); + int longL = testHelper.getLongestL(); Triple multiDataSetTriple; - MultiDataSet shortLongPair, shortSentence, longSentence; + MultiDataSet fromPair, leftSide, rightSide; // check for pair max length exactly equal to sum of lengths - pop neither no padding // should be the same as hstack with segment ids 1 for second sentence prependAppend = true; - multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).addi(1); - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + numOfSentences = 1; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length greater than sum of lengths - pop neither with padding // features should be the same as hstack of shorter and longer padded with prepend/append // segment id should 1 only in the longer for part of the length of the sentence prependAppend = true; - multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + numOfSentences = 1; + multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); //check for pair max length less than shorter sentence - pop both //should be the same as hstack with segment ids 1 for second sentence if no prepend/append - int maxL = shortL - 2; + int maxL = 5;//checking odd + numOfSentences = 3; prependAppend = false; - multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend); - shortLongPair = multiDataSetTriple.getFirst(); - shortSentence = multiDataSetTriple.getSecond(); - longSentence = multiDataSetTriple.getThird(); - assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0))); - longSentence.getFeatures(1).addi(1); - assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1))); - assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0))); + multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences); + fromPair = multiDataSetTriple.getFirst(); + leftSide = multiDataSetTriple.getSecond(); + rightSide = multiDataSetTriple.getThird(); + assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0))); + rightSide.getFeatures(1).addi(1); + assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1))); + assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0))); } + /* + Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs + Checks various max lengths + Has sentences of varying lengths + */ @Test public void testSentencePairsUnequalLengths() throws IOException { - //check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row - //should be identical to hstack if there is no append, prepend - //batch size is 2 - int mbS = 4; - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping - String sent2 = "Goodnight moon"; //shorter than shortSent - no popping - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); - int shortL = t.create(shortSent).countTokens(); - int longL = t.create(longSent).countTokens(); - int sent1L = t.create(sent1).countTokens(); - int sent2L = t.create(sent2).countTokens(); - //won't check 2*shortL + 1 because this will always pop on the left - for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) { + + int minibatchSize = 4; + int numOfSentencesinIter = 3; + + TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter); + int shortL = testPairHelper.getShortL(); + int longL = testPairHelper.getLongL(); + int sent1L = testPairHelper.getSentenceALen(); + int sent2L = testPairHelper.getSentenceBLen(); + + System.out.println("Sentence Pairs, Left"); + System.out.println(testPairHelper.getSentencesLeft()); + System.out.println("Sentence Pairs, Right"); + System.out.println(testPairHelper.getSentencesRight()); + + //anything outside this range more will need to check padding,truncation + for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) { + + System.out.println("Running for max length = " + maxL); + MultiDataSet leftMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either - .sentenceProvider(new TestSentenceProvider()) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider()) .padMinibatches(true) .build().next(); MultiDataSet rightMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either - .sentenceProvider(new TestSentenceProvider(true)) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either + .sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider()) .padMinibatches(true) .build().next(); MultiDataSet pairMDS = BertIterator.builder() - .tokenizer(t) - .minibatchSize(mbS) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) - .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either - .sentencePairProvider(new TestSentencePairProvider()) + .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) + .sentencePairProvider(testPairHelper.getPairSentenceProvider()) .padMinibatches(true) .build().next(); - //Left sentences here are {{shortSent}, - // {longSent}, - // {Sent1}} - //Right sentences here are {{longSent}, - // {shortSent}, - // {Sent2}} - //The sentence pairs here are {{shortSent,longSent}, - // {longSent,shortSent} - // {Sent1, Sent2}} - //CHECK FEATURES - INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL); + INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL); //left side INDArray leftFeatures = leftMDS.getFeatures(0); INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL)); INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL)); - INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L)); + INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L)); //right side INDArray rightFeatures = rightMDS.getFeatures(0); INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL)); INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL)); - INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L)); + INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L)); //expected pair - combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat)); - combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat)); - combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat)); + combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat)); + combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat)); + combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat)); assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]); assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape()); assertEquals(combinedFeat, pairMDS.getFeatures(0)); //CHECK SEGMENT ID - INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL); + INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL); combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1); combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1); - combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1); + combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1); assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape()); assertEquals(maxL, combinedFetSeg.shape()[1]); assertEquals(combinedFetSeg, pairMDS.getFeatures(1)); + + testPairHelper.getPairSentenceProvider().reset(); } } @Test public void testSentencePairFeaturizer() throws IOException { - String shortSent = "I saw a girl with a telescope."; - String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - List> listSentencePair = new ArrayList<>(); - listSentencePair.add(new Pair<>(shortSent, longSent)); - listSentencePair.add(new Pair<>(longSent, shortSent)); - BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + int minibatchSize = 2; + TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize); BertIterator b = BertIterator.builder() - .tokenizer(t) - .minibatchSize(2) + .tokenizer(testPairHelper.getTokenizer()) + .minibatchSize(minibatchSize) .padMinibatches(true) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) - .vocabMap(t.getVocab()) + .vocabMap(testPairHelper.getTokenizer().getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION) .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128) - .sentencePairProvider(new TestSentencePairProvider()) + .sentencePairProvider(testPairHelper.getPairSentenceProvider()) .prependToken("[CLS]") .appendToken("[SEP]") .build(); @@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest { INDArray[] featuresArr = mds.getFeatures(); INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays(); - Pair p = b.featurizeSentencePairs(listSentencePair); + Pair p = b.featurizeSentencePairs(testPairHelper.getSentencePairs()); assertEquals(p.getFirst().length, 2); assertEquals(featuresArr[0], p.getFirst()[0]); assertEquals(featuresArr[1], p.getFirst()[1]); - //assertEquals(p.getSecond().length, 2); assertEquals(featuresMaskArr[0], p.getSecond()[0]); - //assertEquals(featuresMaskArr[1], p.getSecond()[1]); } /** - * Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append + * Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator + * with given max lengths and whether to prepend/append * Idea is the sentence pair dataset can be constructed from the single sentence datasets - * First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" - * Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope." - * Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum" */ - private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend) throws IOException { + private Triple generateMultiDataSets(Triple maxLengths, boolean prependAppend, int numSentences) throws IOException { BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); int maxforPair = maxLengths.getFirst(); int maxPartOne = maxLengths.getSecond(); @@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest { BertIterator.Builder commonBuilder; commonBuilder = BertIterator.builder() .tokenizer(t) - .minibatchSize(1) + .minibatchSize(4) .featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID) .vocabMap(t.getVocab()) .task(BertIterator.Task.SEQ_CLASSIFICATION); - BertIterator shortLongPairFirstIter = commonBuilder + BertIterator pairIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair) - .sentencePairProvider(new TestSentencePairProvider()) + .sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator shortFirstIter = commonBuilder + BertIterator leftIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne) - .sentenceProvider(new TestSentenceProvider()) + .sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider()) .prependToken(prependAppend ? "[CLS]" : null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - BertIterator longFirstIter = commonBuilder + BertIterator rightIter = commonBuilder .lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo) - .sentenceProvider(new TestSentenceProvider(true)) + .sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider()) .prependToken(null) .appendToken(prependAppend ? "[SEP]" : null) .build(); - return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next()); + return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next()); } - private static class TestSentenceProvider implements LabeledSentenceProvider { + @Getter + private static class TestSentencePairsHelper { - private int pos = 0; - private boolean invert; + private List sentencesLeft; + private List sentencesRight; + private List> sentencePairs; + private List> tokenizedSentencesLeft; + private List> tokenizedSentencesRight; + private List labels; + private int shortL; + private int longL; + private int sentenceALen; + private int sentenceBLen; + private BertWordPieceTokenizerFactory tokenizer; + private CollectionLabeledPairSentenceProvider pairSentenceProvider; - private TestSentenceProvider() { - this.invert = false; + private TestSentencePairsHelper() throws IOException { + this(3); } - private TestSentenceProvider(boolean invert) { - this.invert = invert; - } - - @Override - public boolean hasNext() { - return pos < totalNumSentences(); - } - - @Override - public Pair nextSentence() { - Preconditions.checkState(hasNext()); - if (pos == 0) { - pos++; - if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive"); - return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - } else { - if (pos == 1) { - pos++; - if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative"); - return new Pair<>("I saw a girl with a telescope.", "positive"); + private TestSentencePairsHelper(int minibatchSize) throws IOException { + sentencesLeft = new ArrayList<>(); + sentencesRight = new ArrayList<>(); + sentencePairs = new ArrayList<>(); + labels = new ArrayList<>(); + tokenizedSentencesLeft = new ArrayList<>(); + tokenizedSentencesRight = new ArrayList<>(); + tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + sentencesLeft.add(shortSentence); + sentencesRight.add(longSentence); + sentencePairs.add(new Pair<>(shortSentence, longSentence)); + labels.add("positive"); + if (minibatchSize > 1) { + sentencesLeft.add(longSentence); + sentencesRight.add(shortSentence); + sentencePairs.add(new Pair<>(longSentence, shortSentence)); + labels.add("negative"); + if (minibatchSize > 2) { + sentencesLeft.add(sentenceA); + sentencesRight.add(sentenceB); + sentencePairs.add(new Pair<>(sentenceA, sentenceB)); + labels.add("positive"); } - pos++; - if (!invert) - return new Pair<>("Goodnight noises everywhere", "positive"); - return new Pair<>("Goodnight moon", "positive"); } - } - - @Override - public void reset() { - pos = 0; - } - - @Override - public int totalNumSentences() { - return 3; - } - - @Override - public List allLabels() { - return Arrays.asList("positive", "negative"); - } - - @Override - public int numLabelClasses() { - return 2; + for (int i = 0; i < minibatchSize; i++) { + List tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens(); + List tokensR = tokenizer.create(sentencesRight.get(i)).getTokens(); + if (i == 0) { + shortL = tokensL.size(); + longL = tokensR.size(); + } + if (i == 2) { + sentenceALen = tokensL.size(); + sentenceBLen = tokensR.size(); + } + tokenizedSentencesLeft.add(tokensL); + tokenizedSentencesRight.add(tokensR); + } + pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null); } } - private static class TestSentencePairProvider implements LabeledPairSentenceProvider { + @Getter + private static class TestSentenceHelper { - private int pos = 0; + private List sentences; + private List> tokenizedSentences; + private List labels; + private int shortestL = 0; + private int longestL = 0; + private BertWordPieceTokenizerFactory tokenizer; + private CollectionLabeledSentenceProvider sentenceProvider; - @Override - public boolean hasNext() { - return pos < totalNumSentences(); + private TestSentenceHelper() throws IOException { + this(false, 2); } - @Override - public Triple nextSentencePair() { - Preconditions.checkState(hasNext()); - if (pos == 0) { - pos++; - return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive"); - } else { - if (pos == 1) { - pos++; - return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative"); + private TestSentenceHelper(int minibatchSize) throws IOException { + this(false, minibatchSize); + } + + private TestSentenceHelper(boolean alternateOrder) throws IOException { + this(false, 3); + } + + private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException { + sentences = new ArrayList<>(); + labels = new ArrayList<>(); + tokenizedSentences = new ArrayList<>(); + tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c); + if (!alternateOrder) { + sentences.add(shortSentence); + labels.add("positive"); + if (minibatchSize > 1) { + sentences.add(longSentence); + labels.add("negative"); + if (minibatchSize > 2) { + sentences.add(sentenceA); + labels.add("positive"); + } + } + } else { + sentences.add(longSentence); + labels.add("negative"); + if (minibatchSize > 1) { + sentences.add(shortSentence); + labels.add("positive"); + if (minibatchSize > 2) { + sentences.add(sentenceB); + labels.add("positive"); + } } - pos++; - return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive"); } - } - - @Override - public void reset() { - pos = 0; - } - - @Override - public int totalNumSentences() { - return 3; - } - - @Override - public List allLabels() { - return Arrays.asList("positive", "negative"); - } - - @Override - public int numLabelClasses() { - return 2; + for (int i = 0; i < sentences.size(); i++) { + List tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens(); + if (i == 0) + shortestL = tokenizedSentence.size(); + if (tokenizedSentence.size() > longestL) + longestL = tokenizedSentence.size(); + if (tokenizedSentence.size() < shortestL) + shortestL = tokenizedSentence.size(); + tokenizedSentences.add(tokenizedSentence); + } + sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null); } } From 2052ce7026d5d5e2303d3d74533ee8765e7fefb5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Dec 2019 14:20:03 +1100 Subject: [PATCH 02/18] Various pre-release fixes (#111) * Various fixes Signed-off-by: AlexDBlack * Fix default dtypes for MaxPoolWithArgmax Signed-off-by: AlexDBlack --- .../generic/nn/pooling/maxpool_with_argmax.cpp | 2 +- .../imports/converters/ImportClassMapping.java | 1 + .../api/ops/impl/image/NonMaxSuppression.java | 2 +- .../layers/convolution/MaxPoolWithArgmax.java | 4 ++-- .../autodiff/opvalidation/LayerOpValidation.java | 4 ++-- .../nd4j/autodiff/samediff/ConvConfigTests.java | 16 +++++----------- 6 files changed, 12 insertions(+), 17 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index bf5a3eb6e..5fe7455fc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -46,7 +46,7 @@ namespace nd4j { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, DataType::INT64); + ->setAllowedOutputTypes(1, {ALL_INTS}); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 5b60ac0b4..3c6b969b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -103,6 +103,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java index 75b82dc29..f8763c41a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/NonMaxSuppression.java @@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp { @Override public String[] tensorflowNames() { - return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"}; + return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"}; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java index 58602d85e..b966d4389 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/MaxPoolWithArgmax.java @@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { if(attributesForNode.containsKey("argmax")) { outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType()); } else { - outputType = DataType.UINT32; + outputType = DataType.LONG; } } @@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp { Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes); List result = new ArrayList<>(); result.add(inputDataTypes.get(0)); - result.add(outputType == null ? DataType.UINT32 : outputType); + result.add(outputType == null ? DataType.INT : outputType); return result; } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 9dd529399..6a32d9ea9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation { .isSameMode(true) .build(); - SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig); + SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig); assertArrayEquals(inArr.shape(), results[0].eval().shape()); assertArrayEquals(inArr.shape(), results[1].eval().shape()); } @@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation { SDVariable in = sd.var("in", inArr); SDVariable w = sd.var("w", wArr); - SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build()); + SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build()); INDArray expected = Nd4j.createFromArray( new double[][][]{ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java index 996ccff7f..a6f7b6bea 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/ConvConfigTests.java @@ -23,13 +23,7 @@ import static org.junit.Assert.fail; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; public class ConvConfigTests { @@ -489,24 +483,24 @@ public class ConvConfigTests { @Test public void testConv1D(){ - Conv1DConfig.builder().k(2).build(); + Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build(); try{ - Conv1DConfig.builder().k(0).build(); + Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Kernel")); } try{ - Conv1DConfig.builder().k(4).s(-2).build(); + Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Stride")); } try{ - Conv1DConfig.builder().k(3).p(-2).build(); + Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build(); fail(); } catch (IllegalArgumentException e){ assertTrue(e.getMessage().contains("Padding")); From ef4d3ffee823b50a77fbad81658401ed5c72fec0 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 5 Dec 2019 20:59:46 +1100 Subject: [PATCH 03/18] Small pre-release tweak (#112) * Log UI address on launch as in previous Play-based UI Signed-off-by: AlexDBlack * Logging level tweak for UI Signed-off-by: AlexDBlack * http not https Signed-off-by: AlexDBlack --- .../src/main/java/org/deeplearning4j/ui/VertxUIServer.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 2aec66a77..64b033133 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -254,6 +254,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { uiEventRoutingThread = new Thread(new StatsEventRouterRunnable()); uiEventRoutingThread.setDaemon(true); uiEventRoutingThread.start(); + + String address = UIServer.getInstance().getAddress(); + log.info("Deeplearning4j UI server started at: {}", address); } private List extractArgsFromRoute(String path, RoutingContext rc) { @@ -317,7 +320,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { @Override public String getAddress() { - return "https://localhost:" + server.actualPort(); + return "http://localhost:" + server.actualPort(); } @Override @@ -421,7 +424,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } private void runHelper() throws Exception { - log.info("VertxUIServer.StatsEventRouterRunnable started"); + log.trace("VertxUIServer.StatsEventRouterRunnable started"); //Idea: collect all event stats, and route them to the appropriate modules while (!shutdown.get()) { From 0e8a4f77bc0365afb07f99482b1314536370bf85 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Thu, 5 Dec 2019 17:57:32 +0530 Subject: [PATCH 04/18] datavec python ensure host (#113) * ensure host * one more host ensure * info->debug --- .../src/main/java/org/datavec/python/NumpyArray.java | 8 ++++++-- .../main/java/org/datavec/python/PythonExecutioner.java | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index ab49cf5ea..24a2c2e09 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -21,6 +21,7 @@ import lombok.Getter; import lombok.NoArgsConstructor; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; @@ -60,6 +61,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } @@ -85,6 +87,7 @@ public class NumpyArray { setND4JArray(); if (copy){ nd4jArray = nd4jArray.dup(); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); } } @@ -104,11 +107,12 @@ public class NumpyArray { nd4jStrides[i] = strides[i] / elemSize; } - this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); - + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } public NumpyArray(INDArray nd4jArray){ + Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); address = buff.pointer().address(); shape = nd4jArray.shape(); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index c6272e7ad..0f926b9ad 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -605,7 +605,7 @@ public class PythonExecutioner { private static synchronized void _exec(String code) { - log.info(code); + log.debug(code); log.info("CPython: PyRun_SimpleStringFlag()"); int result = PyRun_SimpleStringFlags(code, null); From 355c6b60961ee854db88fd4f24f3f96c62c355f4 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 5 Dec 2019 20:03:10 +0300 Subject: [PATCH 05/18] [WIP] reverse improvements (#115) * initial commit Signed-off-by: raver119 * reverse draft Signed-off-by: raver119 * reverse kernel Signed-off-by: raver119 * reverse kernel Signed-off-by: raver119 --- .../ops/declarable/helpers/cuda/reverse.cu | 95 +++++++++++++++---- .../layers_tests/DeclarableOpsTests1.cpp | 3 +- .../layers_tests/DeclarableOpsTests16.cpp | 41 ++++++++ .../layers_tests/DeclarableOpsTestsCuda1.cu | 18 +++- 4 files changed, 139 insertions(+), 18 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu index aceebf7a0..90e15b21f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/reverse.cu @@ -30,6 +30,67 @@ namespace nd4j { namespace ops { namespace helpers { + template + static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) { + auto input = reinterpret_cast(vinput); + auto output = reinterpret_cast(voutput); + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto step = gridDim.x * blockDim.x; + + // this means that we'll have additional cycle, to move middle element + auto div = numOfElemsToReverse / 2; + auto odd = numOfElemsToReverse % 2 != 0; + auto rlimit = odd ? limit / 2 + 1 : limit / 2; + + // all threads operate in the same input/output space + for (uint64_t e = tid; e < rlimit; e += step) { + // finding out the TAD we're going to process + auto tadId = e / div; + + if (tadId >= numTads) + continue; + + // now finding out element within tad + auto idx = e % div; + + //printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx); + + auto tadInput = input + inputTadOffsets[tadId]; + auto tadOutput = output + outputTadOffsets[tadId]; + + // we're calculating offsets within input TAD + auto fOffset = shape::getIndexOffset(idx, inputTadShape); + auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape); + + // now we're storing input values + auto v1 = tadInput[fOffset]; + auto v2 = tadInput[lOffset]; + + // now we're calculating offsets within output TAD + auto zfOffset = shape::getIndexOffset(idx, outputTadShape); + auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape); + + // and saving values to output arrays + tadOutput[zfOffset] = v2; + tadOutput[zlOffset] = v1; + } + + // moving odd element in blocks + if (odd && threadIdx.x == 0) { + for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) { + auto tadInput = input + inputTadOffsets[e]; + auto tadOutput = output + outputTadOffsets[e]; + + auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape); + auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape); + + tadOutput[zOffset] = tadInput[xOffset]; + } + } + + } + + template static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) { const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -52,7 +113,7 @@ namespace helpers { auto odd = numOfElemsToReverse % 2 != 0; auto limit = numOfElemsToReverse / 2; - for (Nd4jLong e = tid; e < limit; e += step) { + for (uint64_t e = tid; e < limit; e += step) { // we're calculating offsets within input array auto fOffset = shape::getIndexOffset(e, inputShape); auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape); @@ -80,13 +141,19 @@ namespace helpers { } template - static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { + static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) { + auto stream = context->getCudaStream(); + reverseTadKernel<<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength); + } + + template + static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) { auto stream = context->getCudaStream(); Nd4jLong numOfReverse = numOfElemsToReverse; if (numOfElemsToReverse == 0) numOfReverse = input->lengthOf(); - reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); + reverseArrayKernel<<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse); } @@ -153,27 +220,23 @@ namespace helpers { // we need to reverse axis only if that's new op std::vector dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs; std::vector axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions); - auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis); - auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis); + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions); - auto listOut = output->allTensorsAlongDimension(dimensions); - auto listIn = input->allTensorsAlongDimension(dimensions); - NDArray *subArrIn, *subArrOut; NDArray::prepareSpecialUse({output}, {input}); - for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size() - subArrIn = listIn->at(i); - subArrOut = listOut->at(i); - BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES); + + if (packX.numberOfTads() == 1) { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES); + } else { + BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES); } - //BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast(input), output, (int)0), LIBND4J_TYPES); + NDArray::registerSpecialUse({output}, {input}); - delete listOut; - delete listIn; } -BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); +BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES); } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 7036ef77f..60351cc52 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto result = results->at(0); - // result->printBuffer(); + //expected.printIndexedBuffer("E"); + //result->printIndexedBuffer("R"); ASSERT_TRUE(expected.isSameShapeStrict(result)); ASSERT_TRUE(expected.equalsTo(result)); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 38d88b469..f8bf47e53 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -196,4 +196,45 @@ TEST_F(DeclarableOpsTests16, test_range_2) { ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); delete shapes; +} + +TEST_F(DeclarableOpsTests16, test_reverse_1) { + std::vector rows = {3, 5, 7, 8, 9, 10, 119, 211}; + std::vector columns = {6, 5, 10, 100, 153, 171, 635}; + + for (auto r : rows) { + for (auto c : columns) { + //nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', {r, c}); + auto exp = NDArrayFactory::create('c', {r, c}); + auto reversed = NDArrayFactory::create('c', {r, c}); + + auto rowOriginal = NDArrayFactory::create('c', {c}); + auto rowReversed = NDArrayFactory::create('c', {c}); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float) e); + rowReversed.p(c - e - 1, (float) e); + } + + + auto listI = array.allTensorsAlongDimension({1}); + auto listE = exp.allTensorsAlongDimension({1}); + + for (int e = 0; e < r; e++) { + listI->at(e)->assign(rowOriginal); + listE->at(e)->assign(rowReversed); + } + + delete listI; + delete listE; + + nd4j::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({&array}, {&reversed}, {}, {axis}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); + } + } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu index 161b96918..f88cddde5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -24,6 +24,7 @@ #include #include #include +#include using namespace nd4j; @@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { //ASSERT_TRUE(exp.isSameShape(z)); delete result; +} -} \ No newline at end of file +/* +TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) { + auto x = NDArrayFactory::create('c', {1, 3, 608, 608}); + auto z = x.like(); + x.linspace(1.0f); + + nd4j::ops::reverse op; + auto timeStart = std::chrono::system_clock::now(); + auto status = op.execute({&x}, {&z}, {}, {1}, {}); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + nd4j_printf("exec time: %lld us\n", outerTime); + ASSERT_EQ(Status::OK(), status); +} +*/ \ No newline at end of file From 6a3c046ffd1a016102a93e3fba1323333023e19b Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 5 Dec 2019 20:44:11 +0300 Subject: [PATCH 06/18] 2 micro fixes Signed-off-by: raver119 --- libnd4j/blas/NDArray.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index df358b64f..5adff5853 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2348,7 +2348,7 @@ NDArray NDArray::operator-(const NDArray& other) const { NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); + NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&result}, {this, &other}); return result; @@ -2394,7 +2394,7 @@ NDArray NDArray::operator/(const NDArray& other) const { NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext()); NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr); + NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({&result}, {this, &other}); return result; From e09a78523277e8c37c052415091ad9266c2ffee8 Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 5 Dec 2019 21:05:33 +0200 Subject: [PATCH 07/18] Shugeo resize fix5 (#102) * Refactored resize images ops to use TF-like bool args as input. * Refactored helpers for cpu implementation of resize_bilinear and resize_nearest_neighbor ops. * Refactored cuda implementation for image.resize_bilinear and image.resize_nearest_neighbor ops helpers. * Refactored nearest_neighbor resize op. * Added a pair of tests for special case of resize_bilinear algorithm. * Fixed issue with resize_bilinear op. * Refactored cpu implementation for helpers with resize_nearest_neighbor op. * Final fixed for resize ops to conform TF v.1.5 * Refactored cuda helpers for resize_neares_neighbor op. * Fixed resize_bilinear to accept proper data. * Fixed issue with non-float input for resize_bilinear op. * Refactored cuda helper for resize_bilinear to proper process non-float inputs. * Added tests for resize_bilinear to int inputs. * Fixed ResizeBilinear wrapper * Tests fixed * Fixed float and bool constant to avoid overflow for some kind of compilers. * Corrected float constants with float data type. * Added f suffix for float constants. * Corrected float constant to avoid overflow with initializing lists. * Corrected float initializing list with float input. * Corrected bool constant with initalizing list. * Corrected float and bool values with initializing lists. * Fixed wrong constant. * Fixed issue with 1x1 input picture for resize. * ResizeBilinear default values on import fix Signed-off-by: raver119 --- libnd4j/blas/NDArray.h | 47 ++- .../generic/parity_ops/resize_bicubic.cpp | 4 +- .../generic/parity_ops/resize_linear.cpp | 39 ++- .../generic/parity_ops/resize_neighbor.cpp | 33 ++- .../declarable/helpers/cpu/image_resize.cpp | 213 +++++++------- .../declarable/helpers/cuda/image_resize.cu | 264 ++++++++++------- .../ops/declarable/helpers/image_resize.h | 14 +- .../layers_tests/ConvolutionTests1.cpp | 18 +- .../layers_tests/ConvolutionTests2.cpp | 4 +- .../layers_tests/DeclarableOpsTests1.cpp | 6 +- .../layers_tests/DeclarableOpsTests10.cpp | 269 ++++++++++++++---- .../layers_tests/DeclarableOpsTests12.cpp | 4 +- .../layers_tests/DeclarableOpsTests13.cpp | 194 +++++++------ .../layers_tests/DeclarableOpsTests15.cpp | 4 +- .../layers_tests/DeclarableOpsTests2.cpp | 22 +- .../layers_tests/DeclarableOpsTests3.cpp | 18 +- .../layers_tests/DeclarableOpsTests7.cpp | 2 +- .../layers_tests/JavaInteropTests.cpp | 8 +- .../layers_tests/NDArrayCudaBasicsTests.cu | 90 +++--- .../tests_cpu/layers_tests/NDArrayTests.cpp | 86 ++++-- .../api/ops/impl/image/ResizeBilinear.java | 20 +- .../java/org/nd4j/nativeblas/Nd4jCuda.java | 5 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 45 ++- .../TFGraphs/TFGraphTestAllSameDiff.java | 3 - .../nd4j/linalg/custom/CustomOpsTests.java | 3 +- 25 files changed, 917 insertions(+), 498 deletions(-) diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index cfad05b49..d89ef8c72 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -1256,6 +1256,9 @@ namespace nd4j { FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j); template FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k); + template + FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w); + /** * returns array element with given index @@ -1268,6 +1271,8 @@ namespace nd4j { FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const; template FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const; + template + FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const; /** @@ -1711,7 +1716,7 @@ namespace nd4j { if (isEmpty()) return false; - return shape::isMatrix(this->_shapeInfo); + return 0 != shape::isMatrix(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// @@ -1751,7 +1756,7 @@ namespace nd4j { ////////////////////////////////////////////////////////////////////////// bool NDArray::isScalar() const { - return shape::isScalar(this->_shapeInfo); + return 0 != shape::isScalar(this->_shapeInfo); } ////////////////////////////////////////////////////////////////////////// @@ -2082,7 +2087,7 @@ template T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !"); + throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); if (DataTypeUtils::fromT() != _dataType) throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); @@ -2095,6 +2100,23 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) { return *(reinterpret_cast(bufferWithOffset(offset))); } +template +T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) { + + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3)) + throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if(!isActualOnHostSide()) + syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(getShapeInfo(), coords); + tickWriteHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); +} + //////////////////////////////////////////////////////////////////////// template T NDArray::t(const Nd4jLong i) const { @@ -2133,7 +2155,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2)) - throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !"); + throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!"); if (DataTypeUtils::fromT() != _dataType) throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!"); @@ -2146,6 +2168,23 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const { return *(reinterpret_cast(bufferWithOffset(offset))); } + template + T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const { + + if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3)) + throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!"); + if (DataTypeUtils::fromT() != _dataType) + throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!"); + + if(!isActualOnHostSide()) + syncToHost(); + + Nd4jLong coords[4] = {i, j, k, w}; + auto offset = shape::getOffset(getShapeInfo(), coords); + tickReadHost(); + return *(reinterpret_cast(bufferWithOffset(offset))); + } + #ifndef __JAVACPP_HACK__ //////////////////////////////////////////////////////////////////////// std::shared_ptr NDArray::getDataBuffer() const { diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp index 0c1aeba61..99053561c 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp @@ -35,6 +35,8 @@ namespace nd4j { int width; int height; auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank); REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf()); REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf()); @@ -57,7 +59,7 @@ namespace nd4j { if (block.numB()> 1) halfPixelAlign = block.getBArguments()->at(1); } - REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners"); + REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false"); auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp index f60f14fdc..f1f79b08f 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_linear.cpp @@ -32,8 +32,10 @@ namespace nd4j { NDArray* output = OUTPUT_VARIABLE(0); int width; int height; - bool center = false; // - default value + bool alignCorners = false; // - default value auto inRank = image->rankOf(); + if (output->isEmpty()) return Status::OK(); + REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D " "tensor, but input has rank %i", image->rankOf()); @@ -46,21 +48,25 @@ namespace nd4j { auto newImageSize = INPUT_VARIABLE(1); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - if (block.numI() == 1) { - center = 0 != INT_ARG(0); - } + height = newImageSize->e(0); + width = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - if (block.numI() == 3) - center = 0 != INT_ARG(2); + REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); } - return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); + if (block.numB() > 0) + alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) + halfPixelCenter = B_ARG(1); + + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false"); + + return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } DECLARE_SHAPE_FN(resize_bilinear) { @@ -83,7 +89,7 @@ namespace nd4j { height = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided."); + REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided."); width = INT_ARG(0); height = INT_ARG(1); } @@ -101,7 +107,12 @@ namespace nd4j { outputShape[2] = height; outputShape[3] = in[3]; } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + if (DataTypeUtils::isR(ArrayOptions::dataType(in))) { + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } + else { + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); + } shapeList->push_back(CONSTANT(outputShape)); return shapeList; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp index 8733cb9d5..6c18e61e1 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_neighbor.cpp @@ -31,35 +31,40 @@ namespace nd4j { auto image = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); + auto inRank = image->rankOf(); int width; int height; - bool center = false; // - default value + bool alignCorners = false; // - default value + if (output->isEmpty()) return Status::OK(); if (block.width() > 1) { auto newImageSize = INPUT_VARIABLE(1); REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf()); REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive."); - width = newImageSize->e(0); - height = newImageSize->e(1); - if (block.numI() == 1) { - center = 0 != INT_ARG(0); - } + height = newImageSize->e(0); + width = newImageSize->e(1); } else { - REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); - width = INT_ARG(0); - height = INT_ARG(1); - if (block.numI() == 3) - center = 0 != INT_ARG(2); + REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided."); + height = INT_ARG(0); + width = INT_ARG(1); } - auto inRank = image->rankOf(); + if (block.numB() > 0) + alignCorners = B_ARG(0); + bool halfPixelCenter = false; + + if (block.numB() > 1) + halfPixelCenter = B_ARG(1); + REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width); REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured"); REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf()); REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str()); - auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); + REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false"); + REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height); + auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)}); auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}); - return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target); + return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target); } DECLARE_SHAPE_FN(resize_nearest_neighbor) { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index d334caed2..16ddd17da 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -120,6 +120,27 @@ namespace helpers { } }; + // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. + struct HalfPixelScalerNN { + HalfPixelScalerNN(){}; + inline float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale; + } + }; + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + LegacyScaler(){}; + inline float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + struct WeightsAndIndices { float _weight0; float _weight1; @@ -133,7 +154,8 @@ namespace helpers { int _advance; // advance value. }; - inline void computeInterpolationWeights(Nd4jLong outSize, + template + inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize, Nd4jLong inSize, double scale, BilinearInterpolationData *interpolationData) { @@ -143,10 +165,12 @@ namespace helpers { auto func = PRAGMA_THREADS_FOR { for (auto k = start; k < stop; k++) { auto i = (outSize - k - 1); - double in = i * scale; - interpolationData[i]._bottomIndex = static_cast(in); - interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1); - interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex; + double const in = scaler(i, scale); + double const in_f = nd4j::math::nd4j_floor(in); + double const in_c = nd4j::math::nd4j_ceil(in); + interpolationData[i]._bottomIndex = nd4j::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); + interpolationData[i]._topIndex = nd4j::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i]._interpolarValue = in - in_f; } }; samediff::Threads::parallel_for(func, 0, outSize); @@ -156,29 +180,29 @@ namespace helpers { * Computes the bilinear interpolation from the appropriate 4 float points * and the linear interpolation weights. */ - static void - resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const& xs, - std::vector const& ys, - NDArray *output); +// static void +// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// Nd4jLong outWidth, Nd4jLong channels, +// std::vector const& xs, +// std::vector const& ys, +// NDArray *output); - template + template static void - resizeImage_(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, + resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, std::vector const &xs, std::vector const &ys, - NDArray *output) { + Z* pOutputBuf) { Nd4jLong inRowSize = inWidth * channels; Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; - T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction +// T const *pInputBuf = images->getDataBuffer()->primaryAsT(); // this works only with 'c' direction BilinearInterpolationData const* xsPtr = xs.data(); - T* pOutputBuf = output->dataBuffer()->primaryAsT(); +// T* pOutputBuf = output->dataBuffer()->primaryAsT(); auto computeBilinear = [](double topLeft, double topRight, double bottomLeft, double bottomRight, double xVal, double yVal) { @@ -214,8 +238,12 @@ namespace helpers { samediff::Threads::parallel_tad(func, 0, batchSize); } - template - static int resizeBilinearFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) { + template + static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, + bool const halfPixelCenter, NDArray *output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -230,28 +258,20 @@ namespace helpers { return ND4J_STATUS_OK; } - // Special case for TF compatibility - if((center && inHeight < 2) || (center && inWidth < 2)){ - center = false; - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight)); - float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth)); - std::vector ys(outHeight + 1); std::vector xs(outWidth + 1); + if (halfPixelCenter) { + computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale, + ys.data()); + computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data()); - // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights(outHeight, inHeight, heightScale, - ys.data()); - computeInterpolationWeights(outWidth, inWidth, widthScale, xs.data()); - + } + else { + // Compute the cached interpolation weights on the x and y dimensions. + computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale, + ys.data()); + computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data()); + } int xsSize = xs.size(); // Scale x interpolation weights to avoid a multiplication during iteration. auto func = PRAGMA_THREADS_FOR { @@ -262,71 +282,84 @@ namespace helpers { }; samediff::Threads::parallel_for(func, 0, xsSize); - resizeImage(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output); + resizeImage_(images->getDataBuffer()->primaryAsT(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT()); return ND4J_STATUS_OK; } - template - int resizeNeighborFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) { - const Nd4jLong batchSize = images->sizeAt(0); - const Nd4jLong inHeight = images->sizeAt(1); - const Nd4jLong inWidth = images->sizeAt(2); - const Nd4jLong channels = images->sizeAt(3); + template + void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + const Nd4jLong batchSize = st.batchSize; + const Nd4jLong inHeight = st.inHeight; + const Nd4jLong inWidth = st.inWidth; + const Nd4jLong channels = st.channels; - const Nd4jLong outHeight = output->sizeAt(1); - const Nd4jLong outWidth = output->sizeAt(2); - - // Handle no-op resizes efficiently. - if (outHeight == inHeight && outWidth == inWidth) { - output->assign(images); - return Status::OK(); - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight)); - double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth)); + const Nd4jLong outHeight = st.outHeight; + const Nd4jLong outWidth = st.outWidth; + Scaler scaler; auto func = PRAGMA_THREADS_FOR_2D { for (auto b = start_x; b < stop_x; b += inc_x) { for (auto y = start_y; y < stop_y; y += inc_y) { - Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(y * heightScale)) : static_cast(nd4j::math::p_floor(y * heightScale)), inHeight - 1); - + auto posY = alignCorners ? static_cast(nd4j::math::p_round(scaler(y, st.heightScale))) : static_cast(nd4j::math::p_floor(scaler(y, st.heightScale))); + Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenter) { + inY = nd4j::math::nd4j_max(0LL, inY); + } for (auto x = 0; x < outWidth; ++x) { - Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast(nd4j::math::p_round(x * widthScale)) : static_cast(nd4j::math::p_floor(x * widthScale)),inWidth - 1); + auto posX = alignCorners ? static_cast(nd4j::math::p_round(scaler(x, st.widthScale))) : static_cast(nd4j::math::p_floor(scaler(x, st.widthScale))); + Nd4jLong inX = nd4j::math::nd4j_min(posX,inWidth - 1); + if (halfPixelCenter) { + inX = nd4j::math::nd4j_max(0LL, inX); + } + // copy pixel over all channels for (auto e = 0; e < channels; e++) - output->p(b, y, x, e, images->e(b, inY, inX, e)); + output->t(b, y, x, e) = images->t(b, inY, inX, e); } } } }; samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1); + } + + template + int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + ImageResizerState st(alignCorners, halfPixelCenter); + st.validateAndCalculateOutputSize(images, width, height); + + // Handle no-op resizes efficiently. + if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) { + output->assign(images); + return Status::OK(); + } + + if (halfPixelCenter) + resizeNeighbor(st, images, alignCorners, true, output); + else + resizeNeighbor(st, images, alignCorners, false, output); return Status::OK(); } - void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, - Nd4jLong outWidth, Nd4jLong channels, - std::vector const &xs, - std::vector const &ys, - NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, - (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), - LIBND4J_TYPES); +// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, +// Nd4jLong outWidth, Nd4jLong channels, +// std::vector const &xs, +// std::vector const &ys, +// NDArray *output) { +// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_, +// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output), +// NUMERIC_TYPES, FLOAT_TYPES); +// } + + int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray *output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, + (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); } - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, - (images, width, height, center, output), LIBND4J_TYPES); - } - - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) { + int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray *output) { BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, - (images, width, height, center, output), LIBND4J_TYPES); + (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); } @@ -586,16 +619,6 @@ namespace helpers { } } -// Older incorrect scaling method that causes all resizes to have a slight -// translation leading to inconsistent results. For example, a flip then a -// resize gives different results then a resize then a flip. - struct LegacyScaler { - LegacyScaler(){}; - inline float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - static void computeXWeightsAndIndices(const ImageResizerState& resizer_state, const bool half_pixel_centers, std::vector* x_wais) { @@ -847,7 +870,7 @@ namespace helpers { // simplified bicubic resize without antialiasing // template - int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output) { ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align int res = st.validateAndCreateOutput(image, width, height); @@ -856,17 +879,17 @@ namespace helpers { return res; } - int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output) { BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); } // ------------------------------------------------------------------------------------------------------------------ // - int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break; + case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; + case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; case kResizeLanczos5: case kResizeGaussian: diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 0541742ca..4f025d851 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -13,6 +13,20 @@ * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ // // @author sgazeos@gmail.com @@ -32,6 +46,38 @@ namespace helpers { // https://en.wikipedia.org/wiki/Bilinear_interpolation) double interpolarValue; }; + +// Older incorrect scaling method that causes all resizes to have a slight +// translation leading to inconsistent results. For example, a flip then a +// resize gives different results then a resize then a flip. + struct LegacyScaler { + _CUDA_HD LegacyScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + return static_cast(x) * scale; + } + }; + +// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the +// floating point coordinates of the top,left pixel is 0.5,0.5. + struct HalfPixelScaler { + _CUDA_HD HalfPixelScaler(){}; + inline _CUDA_HD float operator()(const int x, const float scale) const { + // Note that we subtract 0.5 from the return value, as the existing bilinear + // sampling code etc assumes pixels are in the old coordinate system. + return (static_cast(x) + 0.5f) * scale - 0.5f; + } + }; + + + // Utility functions + // calculateResizeScale determines the float scaling factor. + inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, + bool alignCorners) { + return (alignCorners && outSize > 1) + ? (inSize - 1) / static_cast(outSize - 1) + : inSize / static_cast(outSize); + } + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // computeInterpolationWeights kernel // outSize - output length @@ -39,6 +85,7 @@ namespace helpers { // scale - input scale // interporationData - result // + template static __global__ void computeInterpolationWeights(Nd4jLong outSize, Nd4jLong inSize, double scale, @@ -48,12 +95,18 @@ namespace helpers { interpolationData[outSize].topIndex = 0; auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto step = blockDim.x * gridDim.x; - + Scaler scaler; for (Nd4jLong i = outSize - tid; i >= 0; i -= step) { - double in = i * scale; - interpolationData[i].bottomIndex = static_cast(in); - interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); - interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; + double in = scaler(i, scale); +// interpolationData[i].bottomIndex = static_cast(in); +// interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1); +// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex; + double const in_f = nd4j::math::p_floor(in); + double const in_c = nd4j::math::p_ceil(in); + interpolationData[i].bottomIndex = nd4j::math::nd4j_max(static_cast(in_f), (Nd4jLong)0LL);//static_cast(in); + interpolationData[i].topIndex = nd4j::math::nd4j_min(static_cast(in_c), inSize - 1); + interpolationData[i].interpolarValue = in - in_f; + if (channels) { math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels); math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels); @@ -72,31 +125,33 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with bilinear interpolation algorithm kernel // - template - static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, T* outputYptr, Nd4jLong* outputShape, Nd4jLong batchSize, - Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, - BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { + template + static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr, + Nd4jLong* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, + Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues, + BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) { for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index auto pX = input + batch * inBatchNumValues; for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) { - const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; - const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; + const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize; + const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize; double yVal = ys_[y].interpolarValue; auto pZ = outputYptr + (batch * outHeight + y) * outRowSize; - for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) { + for (Nd4jLong x = 0; x < outWidth; x++) { auto xsBottom = xs_[x].bottomIndex; auto xsTop = xs_[x].topIndex; auto xVal = xs_[x].interpolarValue; // process interpolation for all channels - for (int c = threadIdx.z; c < channels; c += blockDim.z) { - double topLeft(ys_input_lower_ptr[xsBottom + c]); - double topRight(ys_input_lower_ptr[xsTop + c]); - double bottomLeft(ys_input_upper_ptr[xsBottom + c]); - double bottomRight(ys_input_upper_ptr[xsTop + c]); - double top = topLeft + (topRight - topLeft) * xVal; - double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; - pZ[x * channels + c] = T(top + (bottom - top) * yVal); + for (int c = 0; c < channels; c++) { + Z topLeft(ys_input_lower_ptr[xsBottom + c]); + Z topRight(ys_input_lower_ptr[xsTop + c]); + Z bottomLeft(ys_input_upper_ptr[xsBottom + c]); + Z bottomRight(ys_input_upper_ptr[xsTop + c]); + Z top = topLeft + (topRight - topLeft) * xVal; + Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal; + Z resVal = Z(top + (bottom - top) * yVal); + pZ[x * channels + c] = resVal; } } } @@ -105,7 +160,7 @@ namespace helpers { //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // resize image with - template + template static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, @@ -115,12 +170,13 @@ namespace helpers { Nd4jLong inBatchNumValues = inHeight * inRowSize; Nd4jLong outRowSize = outWidth * channels; auto stream = context->getCudaStream(); - T const *input_b_ptr = reinterpret_cast(images->getSpecialBuffer()); // this works only with 'c' direction - T *output_y_ptr = reinterpret_cast(output->specialBuffer()); + T const* pInput = images->getDataBuffer()->specialAsT(); //reinterpret_cast(images->getSpecialBuffer()); // this works only with 'c' direction + F* pOutput = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); dim3 batchSizeBlock(batchSize, 1, 1); dim3 pictureBlock(outHeight, outWidth, channels); - resizeImageKernel<<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize, - outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_); + resizeImageKernel<<<256, 256, 256, *stream>>>(pInput, images->getSpecialShapeInfo(), pOutput, + output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize, + inBatchNumValues, xs_, ys_); auto err = cudaStreamSynchronize(*stream); if (err != 0) { @@ -129,8 +185,9 @@ namespace helpers { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - template - static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { + template + static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, + int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) { const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -145,19 +202,8 @@ namespace helpers { return ND4J_STATUS_OK; } - // Special case for TF compatibility - if((center && inHeight < 2) || (center && inWidth < 2)){ - center = false; - } - - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight)); - float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); BilinearInterpolationData* xs_;// = xs.data(); BilinearInterpolationData* ys_;// = xs.data(); @@ -173,12 +219,24 @@ namespace helpers { } auto stream = context->getCudaStream(); // Compute the cached interpolation weights on the x and y dimensions. - computeInterpolationWeights<<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); - computeInterpolationWeights<<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); - + if (halfPixelCenter) { + computeInterpolationWeights < + HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights < + HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); + } + else { + computeInterpolationWeights < + LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_); + computeInterpolationWeights < + LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_); + } + printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth); NDArray::prepareSpecialUse({output}, {images}); - resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); + resizeImage_(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output); + err = cudaStreamSynchronize(*stream); NDArray::registerSpecialUse({output}, {images}); + err = cudaFree(xs_); if (err != 0) { throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err); @@ -197,20 +255,28 @@ namespace helpers { // template static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape, - Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center) { + Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) { //for (int b = blockIdx.x; b < batchSize; b += gridDim.x) if (blockIdx.x < batchSize) { auto b = blockIdx.x; for (int y = threadIdx.x; y < outHeight; y += blockDim.x) { - Nd4jLong inY = nd4j::math::nd4j_min( - (center) ? static_cast(nd4j::math::p_round(y * heightScale)) : static_cast(nd4j::math::p_floor( - y * heightScale)), inHeight - 1); + auto posY = alignCorners ? static_cast(nd4j::math::p_round(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast(nd4j::math::p_floor( + halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)); + Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1); + if (halfPixelCenters) { + inY = nd4j::math::nd4j_max(0LL, inY); + } + for (int x = threadIdx.y; x < outWidth; x += blockDim.y) { - Nd4jLong inX = nd4j::math::nd4j_min( - (center) ? static_cast(nd4j::math::p_round(x * widthScale)) : static_cast(nd4j::math::p_floor( - x * widthScale)), inWidth - 1); + auto posX = alignCorners ? static_cast(nd4j::math::p_round(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast(nd4j::math::p_floor( + halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)); + Nd4jLong inX = nd4j::math::nd4j_min(posX, inWidth - 1); + if (halfPixelCenters) { + inX = nd4j::math::nd4j_max(0LL, inX); + } + auto start = blockIdx.z * blockDim.z + threadIdx.z; auto step = blockDim.z * gridDim.z; @@ -231,7 +297,8 @@ namespace helpers { // resizeNeighborFunctor - main algorithm by nearest neighbor // template - int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { + int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenters, NDArray* output) { const Nd4jLong batchSize = images->sizeAt(0); const Nd4jLong inHeight = images->sizeAt(1); const Nd4jLong inWidth = images->sizeAt(2); @@ -246,25 +313,24 @@ namespace helpers { return ND4J_STATUS_OK; } - if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) || - (center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { - // wrong input data - nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); - return ND4J_STATUS_BAD_ARGUMENTS; - } - double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight)); - double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth)); - auto imagesBuffer = reinterpret_cast(images->getSpecialBuffer()); - auto outputBuffer = reinterpret_cast(output->specialBuffer()); +// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) || +// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) { +// // wrong input data +// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", ""); +// return ND4J_STATUS_BAD_ARGUMENTS; +// } +// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight)); +// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth)); + float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners); + float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners); + + auto imagesBuffer = images->getDataBuffer()->specialAsT();//reinterpret_cast(images->getSpecialBuffer()); + auto outputBuffer = output->dataBuffer()->specialAsT();//reinterpret_cast(output->specialBuffer()); auto stream = context->getCudaStream(); - //T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong* outputShape, - // Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center - //input, inputShape, output, outputShape, - // batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center NDArray::prepareSpecialUse({output}, {images}); resizeNeighborKernel<<>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(), - batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center); + batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters); NDArray::registerSpecialUse({output}, {images}); return Status::OK(); @@ -275,39 +341,38 @@ namespace helpers { void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output), LIBND4J_TYPES); + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), + resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, + xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images, + BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, - Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), LIBND4J_TYPES); + Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), + NUMERIC_TYPES, FLOAT_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES); + int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output) { + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images, + width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, +// NDArray const* images, int const width, int const height, bool const alignCorners, +// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES); + int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output) { + BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, + (context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); } - BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images, - int width, int height, bool center, NDArray* output), LIBND4J_TYPES); +// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images, +// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES); //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Bicubic interpolation //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Utility functions and classes - - // calculateResizeScale determines the float scaling factor. - inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize, - bool alignCorners) { - return (alignCorners && outSize > 1) - ? (inSize - 1) / static_cast(outSize - 1) - : inSize / static_cast(outSize); - } - struct ImageResizerState { explicit ImageResizerState(bool alignCorners, bool halfPixelCenters) : _alignCorners(alignCorners), @@ -362,17 +427,6 @@ namespace helpers { bool _halfPixelCenters; }; - // Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the -// floating point coordinates of the top,left pixel is 0.5,0.5. - struct HalfPixelScaler { - _CUDA_HD HalfPixelScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - // Note that we subtract 0.5 from the return value, as the existing bilinear - // sampling code etc assumes pixels are in the old coordinate system. - return (static_cast(x) + 0.5f) * scale - 0.5f; - } - }; - struct WeightsAndIndices { float _weight0; float _weight1; @@ -547,16 +601,6 @@ namespace helpers { } } -// Older incorrect scaling method that causes all resizes to have a slight -// translation leading to inconsistent results. For example, a flip then a -// resize gives different results then a resize then a flip. - struct LegacyScaler { - _CUDA_HD LegacyScaler(){}; - inline _CUDA_HD float operator()(const int x, const float scale) const { - return static_cast(x) * scale; - } - }; - static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) { auto start = blockIdx.x * blockDim.x + threadIdx.x; auto step = blockDim.x * gridDim.x; @@ -906,8 +950,8 @@ namespace helpers { int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) { switch (method) { - case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break; - case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break; + case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break; + case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break; case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break; case kResizeLanczos5: case kResizeGaussian: diff --git a/libnd4j/include/ops/declarable/helpers/image_resize.h b/libnd4j/include/ops/declarable/helpers/image_resize.h index 22c41833b..d52fd74f7 100644 --- a/libnd4j/include/ops/declarable/helpers/image_resize.h +++ b/libnd4j/include/ops/declarable/helpers/image_resize.h @@ -37,15 +37,15 @@ namespace helpers { kResizeArea }; - int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, - NDArray* output); - int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center, - NDArray* output); - int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output); + int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, + bool const alignCorners, bool const halfPixelCenter, NDArray* output); + int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool preserveAspectRatio, bool antialias, NDArray* output); - int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output); - int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, + int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output); void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes, diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 99cc98af9..eccb73c6c 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -419,7 +419,17 @@ TEST_F(ConvolutionTests1, sconv2d_1) { ////////////////////////////////////////////////////////////////////// TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { - TypeParam _expBFF[] = {108.9405008, 109.5920008, 110.2435008, 110.8950008, 111.5465008, 112.1980008, 115.4555008, 116.1070008, 116.7585008, 117.410000, 118.061500, 118.7130009, 121.9705009, 122.6220009, 123.2735009, 123.9250009, 124.5765009, 125.2280009, 128.4855009, 129.1370009, 129.7885009, 130.4400009, 131.09150, 131.74300, 135.0005010, 135.6520010, 136.3035010, 136.9550010, 137.6065010, 138.2580010, 141.5155010, 142.1670010, 142.8185010, 143.4700010, 144.1215010, 144.7730010, 248.9617514, 250.670751, 252.3797515, 254.0887515, 255.7977515, 257.5067515, 266.0517515, 267.7607515, 269.469751, 271.1787516, 272.8877516, 274.5967516, 283.1417516, 284.8507516, 286.5597516, 288.268751, 289.9777517, 291.6867517, 300.2317517, 301.9407517, 303.6497517, 305.3587517, 307.067751, 308.7767518, 317.3217518, 319.0307518, 320.7397518, 322.4487518, 324.157751, 325.866751, 334.4117519, 336.1207519, 337.8297519, 339.5387519, 341.2477519, 342.95675, 388.9829964, 391.7494964, 394.5159964, 397.2824964, 400.048996, 402.8154963, 416.647996, 419.4144962, 422.1809962, 424.9474962, 427.7139962, 430.4804962, 444.3129961, 447.0794961, 449.8459961, 452.6124960, 455.3789960, 458.1454960, 471.9779959, 474.7444959, 477.5109959, 480.2774959, 483.0439959, 485.8104958, 499.6429958, 502.4094957, 505.1759957, 507.9424957, 510.7089957, 513.4754957, 527.3079956, 530.0744956, 532.8409956, 535.607495, 538.3739955, 541.1404955, 529.0042487, 532.8282487, 536.6522487, 540.4762487, 544.3002487, 548.1242487, 567.2442487, 571.068248, 574.892248, 578.716248, 582.540248, 586.3642486, 605.4842486, 609.3082486, 613.1322486, 616.9562486, 620.7802486, 624.6042486, 643.7242486, 647.5482486, 651.3722486, 655.1962486, 659.0202486, 662.8442486, 681.9642486, 685.7882486, 689.6122486, 693.4362486, 697.2602486, 701.0842486, 720.2042486, 724.0282486, 727.852248, 731.676248, 735.500248, 739.324248, 669.0255044, 673.9070044, 678.7885044, 683.6700044, 688.5515044, 693.4330044, 717.8405044, 722.7220044, 727.6035044, 732.4850044, 737.3665044, 742.2480044, 766.6555043, 771.5370043, 776.4185043, 781.3000043, 786.1815043, 791.0630043, 815.4705043, 820.3520043, 825.2335043, 830.1150043, 834.9965043, 839.8780043, 864.2855042, 869.1670042, 874.0485042, 878.9300042, 883.8115042, 888.6930042, 913.1005042, 917.9820042, 922.8635042, 927.7450042, 932.6265042, 937.5080042, 809.0467424, 814.9857424, 820.9247424, 826.8637423, 832.8027423, 838.7417423, 868.4367421, 874.3757421, 880.3147420, 886.2537420, 892.1927420, 898.13174, 927.8267418, 933.7657418, 939.7047417, 945.6437417, 951.5827417, 957.5217416, 987.2167415, 993.155741, 999.0947414, 1005.0337414, 1010.972741, 1016.9117413, 1046.6067412, 1052.5457411, 1058.4847411, 1064.4237411, 1070.3627410, 1076.3017410, 1105.996740, 1111.9357408, 1117.8747408, 1123.8137408, 1129.7527407, 1135.6917407, 949.0679815, 956.0644814, 963.060981, 970.0574813, 977.0539812, 984.0504811, 1019.0329807, 1026.0294807, 1033.0259806, 1040.0224805, 1047.0189804, 1054.0154804, 1088.9979800, 1095.9944799, 1102.9909798, 1109.987479, 1116.9839797, 1123.9804796, 1158.9629792, 1165.9594791, 1172.9559791, 1179.9524790, 1186.9489789, 1193.9454788, 1228.9279785, 1235.9244784, 1242.9209783, 1249.9174782, 1256.913978, 1263.9104781, 1298.8929777, 1305.8894776, 1312.8859775, 1319.8824775, 1326.8789774, 1333.8754773, 1089.0892560, 1097.1432561, 1105.1972562, 1113.251256, 1121.3052563, 1129.3592564, 1169.6292568, 1177.6832568, 1185.7372569, 1193.7912570, 1201.845257, 1209.8992571, 1250.1692575, 1258.2232576, 1266.2772576, 1274.3312577, 1282.3852578, 1290.4392579, 1330.7092582, 1338.7632583, 1346.8172584, 1354.8712584, 1362.9252585, 1370.9792586, 1411.24925, 1419.3032590, 1427.3572591, 1435.4112592, 1443.465259, 1451.5192593, 1491.7892597, 1499.8432598, 1507.8972598, 1515.9512599, 1524.0052600, 1532.059260, 1229.1105073, 1238.2220073, 1247.3335073, 1256.4450073, 1265.5565073, 1274.668007, 1320.2255074, 1329.3370074, 1338.4485074, 1347.5600075, 1356.6715075, 1365.7830075, 1411.340507, 1420.4520076, 1429.5635076, 1438.6750076, 1447.7865076, 1456.8980076, 1502.4555077, 1511.5670077, 1520.6785077, 1529.7900077, 1538.9015077, 1548.013007, 1593.5705078, 1602.6820078, 1611.793507, 1620.9050079, 1630.0165079, 1639.1280079, 1684.6855080, 1693.7970080, 1702.9085080, 1712.0200080, 1721.1315080, 1730.2430080, 1369.1317613, 1379.3007614, 1389.4697614, 1399.6387615, 1409.8077615, 1419.976761, 1470.8217618, 1480.9907618, 1491.159761, 1501.3287619, 1511.4977619, 1521.6667620, 1572.5117622, 1582.6807622, 1592.8497623, 1603.0187623, 1613.1877624, 1623.3567624, 1674.2017626, 1684.3707627, 1694.5397627, 1704.7087628, 1714.8777628, 1725.046762, 1775.8917631, 1786.0607631, 1796.229763, 1806.3987632, 1816.5677632, 1826.7367633, 1877.5817635, 1887.7507635, 1897.9197636, 1908.0887636, 1918.2577637, 1928.4267637, 304.3905022, 305.0420022, 305.6935022, 306.3450022, 306.9965022, 307.6480022, 310.9055022, 311.5570022, 312.208502, 312.860002, 313.5115023, 314.1630023, 317.4205023, 318.0720023, 318.7235023, 319.3750023, 320.0265023, 320.6780023, 323.9355023, 324.5870023, 325.2385023, 325.8900023, 326.541502, 327.193002, 330.4505024, 331.1020024, 331.7535024, 332.4050024, 333.0565024, 333.7080024, 336.9655024, 337.6170024, 338.2685024, 338.9200024, 339.5715024, 340.223002, 761.6617542, 763.3707542, 765.0797542, 766.7887542, 768.4977542, 770.206754, 778.7517543, 780.4607543, 782.1697543, 783.8787543, 785.5877543, 787.2967543, 795.8417544, 797.5507544, 799.2597544, 800.9687544, 802.6777544, 804.3867544, 812.9317545, 814.6407545, 816.3497545, 818.0587545, 819.7677545, 821.4767545, 830.0217546, 831.7307546, 833.4397546, 835.1487546, 836.8577546, 838.5667546, 847.1117547, 848.8207547, 850.5297547, 852.2387547, 853.9477547, 855.6567547, 1218.9329915, 1221.6994915, 1224.4659915, 1227.232491, 1229.9989914, 1232.7654914, 1246.5979913, 1249.3644913, 1252.1309913, 1254.8974913, 1257.6639913, 1260.430491, 1274.2629912, 1277.029491, 1279.7959911, 1282.5624911, 1285.3289911, 1288.0954911, 1301.9279910, 1304.6944910, 1307.4609910, 1310.22749, 1312.9939909, 1315.7604909, 1329.5929908, 1332.3594908, 1335.1259908, 1337.8924908, 1340.6589908, 1343.4254908, 1357.2579907, 1360.0244907, 1362.7909906, 1365.5574906, 1368.3239906, 1371.0904906, 1676.2042479, 1680.0282479, 1683.8522479, 1687.6762479, 1691.5002479, 1695.3242479, 1714.4442479, 1718.2682479, 1722.0922479, 1725.9162479, 1729.7402479, 1733.5642479, 1752.6842479, 1756.5082479, 1760.3322479, 1764.1562479, 1767.9802479, 1771.8042479, 1790.9242479, 1794.7482479, 1798.5722479, 1802.3962479, 1806.2202479, 1810.044247, 1829.1642478, 1832.9882478, 1836.8122478, 1840.6362478, 1844.4602478, 1848.2842478, 1867.4042478, 1871.2282478, 1875.0522478, 1878.8762478, 1882.7002478, 1886.5242478, 2133.4755029, 2138.3570029, 2143.2385029, 2148.1200029, 2153.0015029, 2157.8830029, 2182.2905028, 2187.1720028, 2192.0535028, 2196.9350028, 2201.8165028, 2206.6980028, 2231.1055028, 2235.9870028, 2240.8685028, 2245.7500028, 2250.6315028, 2255.5130028, 2279.9205027, 2284.8020027, 2289.6835027, 2294.5650027, 2299.4465027, 2304.3280027, 2328.7355027, 2333.6170027, 2338.4985027, 2343.3800027, 2348.2615027, 2353.1430027, 2377.5505026, 2382.4320026, 2387.3135026, 2392.1950026, 2397.0765026, 2401.9580026, 2590.7467330, 2596.6857330, 2602.6247329, 2608.5637329, 2614.5027329, 2620.441732, 2650.1367327, 2656.0757327, 2662.0147326, 2667.9537326, 2673.8927326, 2679.8317325, 2709.5267324, 2715.465732, 2721.4047323, 2727.3437323, 2733.282732, 2739.2217322, 2768.9167321, 2774.8557320, 2780.7947320, 2786.7337320, 2792.6727319, 2798.6117319, 2828.306731, 2834.2457317, 2840.1847317, 2846.1237317, 2852.0627316, 2858.0017316, 2887.6967314, 2893.6357314, 2899.5747314, 2905.5137313, 2911.4527313, 2917.3917313, 3048.0179587, 3055.0144586, 3062.0109585, 3069.0074584, 3076.0039584, 3083.0004583, 3117.9829579, 3124.9794578, 3131.9759578, 3138.9724577, 3145.9689576, 3152.9654575, 3187.947957, 3194.9444571, 3201.9409570, 3208.9374569, 3215.933956, 3222.9304568, 3257.9129564, 3264.9094563, 3271.9059562, 3278.9024562, 3285.8989561, 3292.8954560, 3327.8779556, 3334.874455, 3341.8709555, 3348.8674554, 3355.8639553, 3362.860455, 3397.8429549, 3404.8394548, 3411.8359547, 3418.8324546, 3425.8289546, 3432.8254545, 3505.28927, 3513.3432780, 3521.3972781, 3529.4512782, 3537.5052782, 3545.5592783, 3585.8292787, 3593.8832788, 3601.9372788, 3609.9912789, 3618.0452790, 3626.099279, 3666.3692794, 3674.4232795, 3682.4772796, 3690.5312796, 3698.5852797, 3706.6392798, 3746.9092801, 3754.9632802, 3763.0172803, 3771.0712804, 3779.1252804, 3787.1792805, 3827.4492809, 3835.50328, 3843.5572810, 3851.6112811, 3859.6652812, 3867.7192812, 3907.9892816, 3916.0432817, 3924.097281, 3932.1512818, 3940.2052819, 3948.2592820, 3962.5605113, 3971.6720113, 3980.783511, 3989.8950114, 3999.0065114, 4008.1180114, 4053.6755115, 4062.7870115, 4071.8985115, 4081.0100115, 4090.1215115, 4099.2330115, 4144.7905116, 4153.9020116, 4163.0135116, 4172.1250116, 4181.236511, 4190.3480117, 4235.9055117, 4245.0170117, 4254.128511, 4263.2400118, 4272.3515118, 4281.4630118, 4327.0205119, 4336.1320119, 4345.2435119, 4354.3550119, 4363.4665119, 4372.5780119, 4418.1355120, 4427.2470120, 4436.3585120, 4445.4700120, 4454.581512, 4463.6930121, 4419.8317743, 4430.0007744, 4440.1697744, 4450.338774, 4460.5077745, 4470.6767745, 4521.521774, 4531.6907748, 4541.8597748, 4552.0287749, 4562.1977749, 4572.3667750, 4623.2117752, 4633.3807752, 4643.5497753, 4653.7187753, 4663.8877754, 4674.0567754, 4724.9017756, 4735.0707757, 4745.2397757, 4755.4087757, 4765.5777758, 4775.7467758, 4826.591776, 4836.7607761, 4846.9297761, 4857.0987762, 4867.2677762, 4877.4367763, 4928.2817765, 4938.4507765, 4948.6197766, 4958.7887766, 4968.957776, 4979.12677675}; + TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, + 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, + 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, + 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, + 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, + 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, + 3666.3692794f, 3674.4232795f, 3682.4772796f, 3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, + 3932.1512818f, 3940.2052819f, 3948.2592820f, 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, + 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, + 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; NDArray expFF(_expBFF, _expSFF); @@ -625,11 +635,11 @@ TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { } TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { - TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.}; + TypeParam _expBFF[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 53150.0f, 55350.0f, 57550.0f, 59750.0f, 61950.0f, 64150.0f, 75150.0f, 77350.0f, 79550.0f, 81750.0f, 83950.0f, 86150.0f, 97150.0f, 99350.0f, 101550.0f, 103750.0f, 105950.0f, 108150.0f, 119150.0f, 121350.0f, 123550.0f, 125750.0f, 127950.0f, 130150.0f, 141150.0f, 143350.0f, 145550.0f, 147750.0f, 149950.0f, 152150.0f, 163150.0f, 165350.0f, 167550.0f, 169750.0f, 171950.0f, 174150.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 350025.0f, 352850.0f, 355675.0f, 358500.0f, 361325.0f, 364150.0f, 378275.0f, 381100.0f, 383925.0f, 386750.0f, 389575.0f, 392400.0f, 406525.0f, 409350.0f, 412175.0f, 415000.0f, 417825.0f, 420650.0f, 434775.0f, 437600.0f, 440425.0f, 443250.0f, 446075.0f, 448900.0f, 463025.0f, 465850.0f, 468675.0f, 471500.0f, 474325.0f, 477150.0f, 491275.0f, 494100.0f, 496925.0f, 499750.0f, 502575.0f, 505400.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 771900.0f, 775350.0f, 778800.0f, 782250.0f, 785700.0f, 789150.0f, 806400.0f, 809850.0f, 813300.0f, 816750.0f, 820200.0f, 823650.0f, 840900.0f, 844350.0f, 847800.0f, 851250.0f, 854700.0f, 858150.0f, 875400.0f, 878850.0f, 882300.0f, 885750.0f, 889200.0f, 892650.0f, 909900.0f, 913350.0f, 916800.0f, 920250.0f, 923700.0f, 927150.0f, 944400.0f, 947850.0f, 951300.0f, 954750.0f, 958200.0f, 961650.0f, 107525.0f, 107850.0f, 108175.0f, 108500.0f, 108825.0f, 109150.0f, 110775.0f, 111100.0f, 111425.0f, 111750.0f, 112075.0f, 112400.0f, 114025.0f, 114350.0f, 114675.0f, 115000.0f, 115325.0f, 115650.0f, 117275.0f, 117600.0f, 117925.0f, 118250.0f, 118575.0f, 118900.0f, 120525.0f, 120850.0f, 121175.0f, 121500.0f, 121825.0f, 122150.0f, 123775.0f, 124100.0f, 124425.0f, 124750.0f, 125075.0f, 125400.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 404400.0f, 405350.0f, 406300.0f, 407250.0f, 408200.0f, 409150.0f, 413900.0f, 414850.0f, 415800.0f, 416750.0f, 417700.0f, 418650.0f, 423400.0f, 424350.0f, 425300.0f, 426250.0f, 427200.0f, 428150.0f, 432900.0f, 433850.0f, 434800.0f, 435750.0f, 436700.0f, 437650.0f, 442400.0f, 443350.0f, 444300.0f, 445250.0f, 446200.0f, 447150.0f, 451900.0f, 452850.0f, 453800.0f, 454750.0f, 455700.0f, 456650.0f, 1197525.0f, 1200350.0f, 1203175.0f, 1206000.0f, 1208825.0f, 1211650.0f, 1225775.0f, 1228600.0f, 1231425.0f, 1234250.0f, 1237075.0f, 1239900.0f, 1254025.0f, 1256850.0f, 1259675.0f, 1262500.0f, 1265325.0f, 1268150.0f, 1282275.0f, 1285100.0f, 1287925.0f, 1290750.0f, 1293575.0f, 1296400.0f, 1310525.0f, 1313350.0f, 1316175.0f, 1319000.0f, 1321825.0f, 1324650.0f, 1338775.0f, 1341600.0f, 1344425.0f, 1347250.0f, 1350075.0f, 1352900.0f, 826275.0f, 827850.0f, 829425.0f, 831000.0f, 832575.0f, 834150.0f, 842025.0f, 843600.0f, 845175.0f, 846750.0f, 848325.0f, 849900.0f, 857775.0f, 859350.0f, 860925.0f, 862500.0f, 864075.0f, 865650.0f, 873525.0f, 875100.0f, 876675.0f, 878250.0f, 879825.0f, 881400.0f, 889275.0f, 890850.0f, 892425.0f, 894000.0f, 895575.0f, 897150.0f, 905025.0f, 906600.0f, 908175.0f, 909750.0f, 911325.0f, 912900.0f, 1806900.0f, 1810350.0f, 1813800.0f, 1817250.0f, 1820700.0f, 1824150.0f, 1841400.0f, 1844850.0f, 1848300.0f, 1851750.0f, 1855200.0f, 1858650.0f, 1875900.0f, 1879350.0f, 1882800.0f, 1886250.0f, 1889700.0f, 1893150.0f, 1910400.0f, 1913850.0f, 1917300.0f, 1920750.0f, 1924200.0f, 1927650.0f, 1944900.0f, 1948350.0f, 1951800.0f, 1955250.0f, 1958700.0f, 1962150.0f, 1979400.0f, 1982850.0f, 1986300.0f, 1989750.0f, 1993200.0f, 1996650.f}; Nd4jLong _expSFF[] = {4, 2, 6, 6, 6, 216, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; NDArray expFF(_expBFF, _expSFF); - TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062}; - Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99,}; + TypeParam _exp2BFF[] = {827.4900282f, 832.2350283f, 836.9800284f, 841.725028f, 846.4700287f, 851.2150288f, 874.9400293f, 879.6850294f, 884.4300295f, 889.1750296f, 893.9200297f, 898.665029f, 922.3900304f, 927.1350305f, 931.8800306f, 936.6250307f, 941.3700308f, 946.1150309f, 969.8400315f, 974.5850316f, 979.3300317f, 984.0750318f, 988.8200319f, 993.5650320f, 1017.2900326f, 1022.0350327f, 1026.7800328f, 1031.5250329f, 1036.2700330f, 1041.0150331f, 1064.7400337f, 1069.4850338f, 1074.2300339f, 1078.9750340f, 1083.7200341f, 1088.4650342f, 1822.4550553f, 1833.995055f, 1845.5350558f, 1857.075056f, 1868.6150563f, 1880.1550566f, 1937.8550578f, 1949.3950581f, 1960.9350583f, 1972.4750586f, 1984.015058f, 1995.5550591f, 2053.2550604f, 2064.7950606f, 2076.3350609f, 2087.8750611f, 2099.4150614f, 2110.955061f, 2168.6550629f, 2180.1950632f, 2191.7350634f, 2203.2750637f, 2214.8150639f, 2226.3550642f, 2284.0550655f, 2295.5950657f, 2307.1350660f, 2318.6750662f, 2330.2150665f, 2341.7550667f, 2399.4550680f, 2410.9950683f, 2422.5350685f, 2434.0750688f, 2445.6150690f, 2457.1550693f, 2817.419968f, 2835.7549686f, 2854.0899683f, 2872.4249680f, 2890.7599677f, 2909.0949674f, 3000.7699660f, 3019.104965f, 3037.4399655f, 3055.7749652f, 3074.1099649f, 3092.4449646f, 3184.1199632f, 3202.4549629f, 3220.789962f, 3239.1249624f, 3257.4599621f, 3275.7949618f, 3367.4699604f, 3385.8049601f, 3404.1399598f, 3422.474959f, 3440.8099593f, 3459.1449590f, 3550.8199576f, 3569.1549573f, 3587.4899570f, 3605.8249567f, 3624.1599565f, 3642.4949562f, 3734.1699548f, 3752.5049545f, 3770.8399542f, 3789.1749539f, 3807.5099536f, 3825.8449534f, 3812.385098f, 3837.5150988f, 3862.6450994f, 3887.7751000f, 3912.9051006f, 3938.0351012f, 4063.6851041f, 4088.8151047f, 4113.9451053f, 4139.0751059f, 4164.2051065f, 4189.3351071f, 4314.9851100f, 4340.1151106f, 4365.2451112f, 4390.3751118f, 4415.5051124f, 4440.6351130f, 4566.2851159f, 4591.4151165f, 4616.5451171f, 4641.6751177f, 4666.805118f, 4691.9351188f, 4817.5851218f, 4842.7151224f, 4867.8451230f, 4892.975123f, 4918.1051241f, 4943.2351247f, 5068.8851277f, 5094.0151283f, 5119.1451288f, 5144.2751294f, 5169.4051300f, 5194.5351306f, 4807.3499803f, 4839.2749801f, 4871.1999799f, 4903.1249797f, 4935.0499795f, 4966.9749793f, 5126.5999784f, 5158.5249782f, 5190.4499780f, 5222.3749778f, 5254.2999777f, 5286.2249775f, 5445.8499765f, 5477.774976f, 5509.6999762f, 5541.6249760f, 5573.5499758f, 5605.4749756f, 5765.0999747f, 5797.0249745f, 5828.9499743f, 5860.8749741f, 5892.7999739f, 5924.724973f, 6084.3499728f, 6116.2749726f, 6148.1999724f, 6180.1249723f, 6212.0499721f, 6243.9749719f, 6403.59997f, 6435.5249708f, 6467.4499706f, 6499.3749704f, 6531.2999702f, 6563.2249700f, 5802.3150007f, 5841.0350006f, 5879.7550005f, 5918.4750004f, 5957.195000f, 5995.9150003f, 6189.5149999f, 6228.2349998f, 6266.9549997f, 6305.6749996f, 6344.3949995f, 6383.114999f, 6576.7149990f, 6615.4349990f, 6654.1549989f, 6692.8749988f, 6731.5949987f, 6770.3149986f, 6963.9149982f, 7002.6349981f, 7041.3549981f, 7080.0749980f, 7118.7949979f, 7157.5149978f, 7351.1149974f, 7389.8349973f, 7428.5549972f, 7467.2749972f, 7505.9949971f, 7544.7149970f, 7738.3149966f, 7777.0349965f, 7815.7549964f, 7854.4749963f, 7893.1949963f, 7931.9149962f, 6797.2799488f, 6842.794948f, 6888.3099489f, 6933.8249490f, 6979.3399491f, 7024.8549492f, 7252.4299497f, 7297.9449498f, 7343.4599499f, 7388.9749500f, 7434.489950f, 7480.0049501f, 7707.5799506f, 7753.0949507f, 7798.6099508f, 7844.1249509f, 7889.6399510f, 7935.1549511f, 8162.7299515f, 8208.2449516f, 8253.7599517f, 8299.2749518f, 8344.7899519f, 8390.3049520f, 8617.8799525f, 8663.394952f, 8708.9099526f, 8754.4249527f, 8799.9399528f, 8845.4549529f, 9073.0299534f, 9118.5449535f, 9164.0599536f, 9209.5749537f, 9255.089953f, 9300.604953f, 7792.2451647f, 7844.5551655f, 7896.8651663f, 7949.1751671f, 8001.4851679f, 8053.7951686f, 8315.3451725f, 8367.6551733f, 8419.9651741f, 8472.2751749f, 8524.585175f, 8576.8951764f, 8838.4451803f, 8890.7551811f, 8943.0651819f, 8995.3751827f, 9047.6851834f, 9099.9951842f, 9361.5451881f, 9413.8551889f, 9466.1651897f, 9518.475190f, 9570.7851912f, 9623.0951920f, 9884.6451959f, 9936.9551967f, 9989.2651975f, 10041.5751982f, 10093.8851990f, 10146.1951998f, 10407.7452037f, 10460.0552045f, 10512.3652053f, 10564.6752060f, 10616.9852068f, 10669.2952076f, 8787.210074f, 8846.3150748f, 8905.4200750f, 8964.5250752f, 9023.6300755f, 9082.7350757f, 9378.2600768f, 9437.3650770f, 9496.4700773f, 9555.5750775f, 9614.6800777f, 9673.7850779f, 9969.3100791f, 10028.4150793f, 10087.5200795f, 10146.625079f, 10205.7300800f, 10264.8350802f, 10560.3600813f, 10619.465081f, 10678.5700818f, 10737.6750820f, 10796.7800822f, 10855.8850825f, 11151.4100836f, 11210.5150838f, 11269.6200840f, 11328.7250843f, 11387.8300845f, 11446.9350847f, 11742.4600858f, 11801.5650861f, 11860.6700863f, 11919.7750865f, 11978.880086f, 12037.9850870f, 9782.1750935f, 9848.0750935f, 9913.9750934f, 9979.8750934f, 10045.7750934f, 10111.6750933f, 10441.1750931f, 10507.0750931f, 10572.9750931f, 10638.8750930f, 10704.7750930f, 10770.6750930f, 11100.1750928f, 11166.0750927f, 11231.9750927f, 11297.8750927f, 11363.7750926f, 11429.6750926f, 11759.1750924f, 11825.0750924f, 11890.9750923f, 11956.8750923f, 12022.7750923f, 12088.6750922f, 12418.175092f, 12484.0750920f, 12549.9750920f, 12615.8750919f, 12681.7750919f, 12747.6750919f, 13077.1750917f, 13143.0750916f, 13208.9750916f, 13274.8750916f, 13340.7750915f, 13406.6750915f, 2250.990060f, 2255.7350610f, 2260.4800611f, 2265.2250612f, 2269.9700613f, 2274.7150614f, 2298.4400619f, 2303.185062f, 2307.9300622f, 2312.6750623f, 2317.4200624f, 2322.1650625f, 2345.8900630f, 2350.6350631f, 2355.380063f, 2360.1250634f, 2364.8700635f, 2369.6150636f, 2393.3400641f, 2398.0850642f, 2402.8300643f, 2407.5750644f, 2412.320064f, 2417.0650647f, 2440.7900652f, 2445.5350653f, 2450.2800654f, 2455.0250655f, 2459.7700656f, 2464.515065f, 2488.2400663f, 2492.9850664f, 2497.7300665f, 2502.4750666f, 2507.2200667f, 2511.9650668f, 5284.4551315f, 5295.9951318f, 5307.535132f, 5319.0751323f, 5330.6151326f, 5342.1551328f, 5399.8551341f, 5411.3951343f, 5422.9351346f, 5434.475134f, 5446.0151351f, 5457.5551354f, 5515.2551366f, 5526.7951369f, 5538.3351371f, 5549.8751374f, 5561.4151376f, 5572.9551379f, 5630.6551392f, 5642.1951394f, 5653.7351397f, 5665.2751399f, 5676.8151402f, 5688.3551404f, 5746.0551417f, 5757.5951420f, 5769.1351422f, 5780.6751425f, 5792.2151427f, 5803.7551430f, 5861.455144f, 5872.9951445f, 5884.5351448f, 5896.0751450f, 5907.6151453f, 5919.1551455f, 8317.919884f, 8336.2548841f, 8354.5898838f, 8372.9248835f, 8391.2598832f, 8409.59488f, 8501.2698815f, 8519.6048813f, 8537.9398810f, 8556.2748807f, 8574.6098804f, 8592.9448801f, 8684.6198787f, 8702.9548784f, 8721.2898782f, 8739.6248779f, 8757.9598776f, 8776.2948773f, 8867.9698759f, 8886.3048756f, 8904.6398753f, 8922.9748751f, 8941.3098748f, 8959.6448745f, 9051.3198731f, 9069.6548728f, 9087.9898725f, 9106.3248722f, 9124.6598720f, 9142.9948717f, 9234.6698703f, 9253.0048700f, 9271.3398697f, 9289.6748694f, 9308.0098691f, 9326.3448689f, 11351.3852747f, 11376.5152753f, 11401.6452759f, 11426.7752765f, 11451.9052771f, 11477.0352777f, 11602.6852806f, 11627.8152812f, 11652.9452818f, 11678.0752824f, 11703.2052830f, 11728.335283f, 11853.9852865f, 11879.1152871f, 11904.2452877f, 11929.3752883f, 11954.505288f, 11979.6352894f, 12105.2852924f, 12130.4152930f, 12155.545293f, 12180.6752941f, 12205.8052947f, 12230.9352953f, 12356.5852983f, 12381.715298f, 12406.8452994f, 12431.9753000f, 12457.1053006f, 12482.2353012f, 12607.8853041f, 12633.0153047f, 12658.1453053f, 12683.2753059f, 12708.4053065f, 12733.5353071f, 14384.8499244f, 14416.7749242f, 14448.6999240f, 14480.6249238f, 14512.549923f, 14544.4749235f, 14704.0999225f, 14736.024922f, 14767.9499222f, 14799.8749220f, 14831.7999218f, 14863.7249216f, 15023.3499207f, 15055.2749205f, 15087.1999203f, 15119.1249201f, 15151.0499199f, 15182.9749197f, 15342.5999188f, 15374.5249186f, 15406.4499184f, 15438.374918f, 15470.2999181f, 15502.2249179f, 15661.84991f, 15693.7749168f, 15725.6999166f, 15757.6249164f, 15789.5499162f, 15821.4749160f, 15981.0999151f, 16013.0249149f, 16044.9499147f, 16076.8749145f, 16108.7999143f, 16140.7249142f, 17418.314976f, 17457.0349761f, 17495.7549760f, 17534.4749759f, 17573.1949758f, 17611.9149757f, 17805.5149753f, 17844.234975f, 17882.9549752f, 17921.6749751f, 17960.3949750f, 17999.1149749f, 18192.7149745f, 18231.4349744f, 18270.154974f, 18308.8749743f, 18347.5949742f, 18386.3149741f, 18579.9149737f, 18618.6349736f, 18657.3549735f, 18696.074973f, 18734.7949734f, 18773.5149733f, 18967.1149729f, 19005.8349728f, 19044.5549727f, 19083.2749726f, 19121.994972f, 19160.7149725f, 19354.3149721f, 19393.0349720f, 19431.7549719f, 19470.4749718f, 19509.1949717f, 19547.914971f, 20451.7799765f, 20497.2949766f, 20542.8099767f, 20588.3249768f, 20633.8399769f, 20679.3549770f, 20906.929977f, 20952.4449775f, 20997.9599776f, 21043.4749777f, 21088.9899778f, 21134.5049779f, 21362.0799784f, 21407.5949785f, 21453.1099786f, 21498.624978f, 21544.139978f, 21589.6549788f, 21817.2299793f, 21862.7449794f, 21908.2599795f, 21953.7749796f, 21999.2899797f, 22044.8049798f, 22272.3799802f, 22317.8949803f, 22363.4099804f, 22408.9249805f, 22454.4399806f, 22499.9549807f, 22727.529981f, 22773.044981f, 22818.5599813f, 22864.0749814f, 22909.5899815f, 22955.1049816f, 23485.2453985f, 23537.555399f, 23589.8654000f, 23642.1754008f, 23694.4854016f, 23746.7954024f, 24008.3454063f, 24060.655407f, 24112.9654078f, 24165.2754086f, 24217.5854094f, 24269.8954102f, 24531.4454141f, 24583.7554148f, 24636.0654156f, 24688.3754164f, 24740.6854172f, 24792.99541f, 25054.545421f, 25106.8554226f, 25159.1654234f, 25211.4754242f, 25263.7854250f, 25316.0954257f, 25577.6454296f, 25629.9554304f, 25682.2654312f, 25734.5754320f, 25786.8854328f, 25839.1954335f, 26100.7454374f, 26153.0554382f, 26205.3654390f, 26257.6754398f, 26309.985440f, 26362.2954413f, 26518.7101423f, 26577.8151425f, 26636.920142f, 26696.0251430f, 26755.1301432f, 26814.2351434f, 27109.7601446f, 27168.8651448f, 27227.9701450f, 27287.0751452f, 27346.1801455f, 27405.2851457f, 27700.8101468f, 27759.9151470f, 27819.0201473f, 27878.1251475f, 27937.2301477f, 27996.33514f, 28291.8601491f, 28350.9651493f, 28410.0701495f, 28469.175149f, 28528.2801500f, 28587.3851502f, 28882.9101513f, 28942.0151516f, 29001.1201518f, 29060.2251520f, 29119.3301522f, 29178.4351525f, 29473.9601536f, 29533.0651538f, 29592.1701540f, 29651.2751543f, 29710.3801545f, 29769.4851547f, 29552.1750826f, 29618.0750825f, 29683.9750825f, 29749.8750825f, 29815.7750824f, 29881.6750824f, 30211.1750822f, 30277.0750822f, 30342.9750821f, 30408.8750821f, 30474.7750821f, 30540.6750820f, 30870.175081f, 30936.0750818f, 31001.9750818f, 31067.8750817f, 31133.7750817f, 31199.6750817f, 31529.1750815f, 31595.075081f, 31660.9750814f, 31726.8750814f, 31792.7750813f, 31858.6750813f, 32188.1750811f, 32254.0750811f, 32319.975081f, 32385.8750810f, 32451.7750810f, 32517.6750809f, 32847.1750808f, 32913.0750807f, 32978.9750807f, 33044.875080f, 33110.7750806f, 33176.67508062f}; + Nd4jLong _exp2SFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray exp2FF(_exp2BFF, _exp2SFF); auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 4cbf6b6dd..de3cdcdba 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -212,12 +212,12 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) { } TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { - TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139}; + TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray expGWP(_expGradWpB, _expGradWpS); expGWP.permutei({2,3,1,0}); - TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747}; + TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f}; Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; NDArray expGWD(_expGradWdB, _expGradWdS); expGWD.permutei({2,3,1,0}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 60351cc52..591746804 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests1, TestGemv1) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00,64.00,100.00,136.00,172.00}; + auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -3606,7 +3606,9 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) { auto input = NDArrayFactory::create('c', {2,3,4}); - auto expected = NDArrayFactory::create('c', {2,3,4}, {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, + 15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, + 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); input.linspace(1); nd4j::ops::reverse op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 6375d935c..21c18299e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -121,10 +121,10 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) { } TEST_F(DeclarableOpsTests10, Test_Not_1) { - auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); - auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); // auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); - auto e = NDArrayFactory::create('c', {4}, {0, 0, 1, 0}); + auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); nd4j::ops::boolean_not op; auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL); @@ -245,7 +245,8 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { - auto cond2d = NDArrayFactory::create('c', {3, 5}, {1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1}); + auto cond2d = NDArrayFactory::create('c', {3, 5}, {true, true, false, false, true, true, true, + true, true, true, false, true, true, true, true}); // auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); @@ -623,7 +624,7 @@ TEST_F(DeclarableOpsTests10, range_test11) { ////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, range_test12) { - auto exp = NDArrayFactory::create('c', {9}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5}); + auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); nd4j::ops::range op; auto result = op.execute({}, {0.5, 5, 0.5}, {}, {}); @@ -1416,7 +1417,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) { //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { - NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, @@ -1470,6 +1471,138 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { delete results; } +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input, &size}, {}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + ASSERT_NE(*result, ex); + + delete results; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + ASSERT_NE(*result, ex); + + delete results; +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1., 2., 3., 4., + 2.6, 3.6, 4.6, 5.6, + 5., 6., 7., 8., + 7.4, 8.4, 9.4, 10.4, + 9., 10., 11., 12., + + 4., 5., 6., 7., + 5.6, 6.6, 7.6, 8.6, + 8., 9., 10., 11., + 10.4, 11.4, 12.4, 13.4, + 12., 13., 14., 15., + + 10., 11., 12., 13., + 11.6, 12.6, 13.6, 14.6, + 14., 15., 16., 17., + 16.4, 17.4, 18.4, 19.4, + 18., 19., 20., 21., + + 13., 14., 15., 16., + 14.6, 15.6, 16.6, 17.6, + 17., 18., 19., 20., + 19.4, 20.4, 21.4, 22.4, + 21., 22., 23., 24. + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); + //expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 2.6f, 3.6f, 4.6f, 5.6f, + 5.f, 6.f, 7.f, 8.f, + 7.4f, 8.4f, 9.4f, 10.4f, + 9.f, 10.f, 11.f, 12.f, + + 4.f, 5.f, 6.f, 7.f, + 5.6f, 6.6f, 7.6f, 8.6f, + 8.f, 9.f, 10.f, 11.f, + 10.4f, 11.4f, 12.4f, 13.4f, + 12.f, 13.f, 14.f, 15.f, + + 10.f, 11.f, 12.f, 13.f, + 11.6f, 12.6f, 13.6f, 14.6f, + 14.f, 15.f, 16.f, 17.f, + 16.4f, 17.4f, 18.4f, 19.4f, + 18.f, 19.f, 20.f, 21.f, + + 13.f, 14.f, 15.f, 16.f, + 14.6f, 15.6f, 16.6f, 17.6f, + 17.f, 18.f, 19.f, 20.f, + 19.4f, 20.4f, 21.4f, 22.4f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_bilinear op; + auto results = op.execute({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { NDArray input = NDArrayFactory::create('c', {2,3,4}); @@ -1857,7 +1990,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { input.linspace(1); nd4j::ops::resize_bilinear op; - auto results = op.execute({&input}, {}, {10, 10, 1}); + auto results = op.execute({&input}, {}, {10, 10}, {true}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -1986,7 +2119,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { input.linspace(1); nd4j::ops::resize_bilinear op; - auto results = op.execute({&input, &size}, {}, {1}); + auto results = op.execute({&input, &size}, {}, {}, {true}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2023,7 +2156,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { 1, 2, 3, 4, + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, @@ -2051,7 +2185,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { input.linspace(1); nd4j::ops::resize_nearest_neighbor op; - auto results = op.execute({&input}, {}, {4, 5}); + auto results = op.execute({&input}, {}, {4, 5}, {false, false}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -2070,7 +2204,8 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); - NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { 1, 2, 3, 4, + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, @@ -2112,6 +2247,54 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { delete results; } +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + nd4j::ops::resize_nearest_neighbor op; + auto results = op.execute({&input}, {}, {4,5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printIndexedBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete results; +} + TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { NDArray input = NDArrayFactory::create('c', {2, 3, 4}); @@ -2533,7 +2716,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32); + NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32); nd4j::ops::crop_and_resize op; auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0}); @@ -2557,7 +2740,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { NDArray cropSize = NDArrayFactory::create({3, 3}); //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); - NDArray expected('c', {1,3,3,1}, {1, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32); + NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32); nd4j::ops::crop_and_resize op; auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1}); @@ -2726,7 +2909,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32); - NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32); + NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32); NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32); NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32); @@ -2971,22 +3154,6 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { delete results; } -/* public void testFakeQuantAgainstTF_1() { - INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, - 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, - 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5); - INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5); - INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5); - - INDArray out = Nd4j.createUninitialized(x.shape()); - val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out); - - INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, - 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, - 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); - - assertEquals(expected, out); - }*/ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, @@ -3094,12 +3261,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { TEST_F(DeclarableOpsTests10, batchnorm_test1) { NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); @@ -3211,19 +3378,19 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) { TEST_F(DeclarableOpsTests10, batchnorm_test5) { NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20., - -19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438, - -12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, + -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); nd4j::ops::batchnorm op; - auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); @@ -3240,14 +3407,14 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) { TEST_F(DeclarableOpsTests10, batchnorm_test6) { NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32); - NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32); - NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32); - NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32); - NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 , - 20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 , - -12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32); + NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, + 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, + -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32); input.linspace(0.1, 0.1); nd4j::ops::batchnorm op; @@ -3270,7 +3437,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32); NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32); - NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, nd4j::DataType::BOOL); + NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, nd4j::DataType::BOOL); NDArray result('c', {2,2,2}, nd4j::DataType::BOOL); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 67ecf5576..5ca22c95e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) { auto input = NDArrayFactory::create('c', {4, 5}); auto idx = NDArrayFactory::create('c', {4}); - auto exp = NDArrayFactory::create({0, 0, 0, 1}); + auto exp = NDArrayFactory::create({false, false, false, true}); int exclusive, reverse; input.linspace(1); @@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) { TEST_F(DeclarableOpsTests12, inTopK_5) { auto x = NDArrayFactory::create('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); - auto expV = NDArrayFactory::create('f', {6}, {1, 0, 0, 0, 0, 0 }); + auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); nd4j::ops::in_top_k op; auto result = op.execute({&x, &y}, {}, {2}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 76a44be0b..91ff89d46 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1167,12 +1167,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139, - 0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106, - 0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, + 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f, + 0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1230,12 +1230,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; @@ -1245,18 +1245,19 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289, - -0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647, - 0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395, - -0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862, - 0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308, - -0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128, - 0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, 2 * nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, + -0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, + 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f, + -0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, + 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f, + -0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, + 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, - -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, - -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f, + -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f, + -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1328,16 +1329,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406, - 0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 , - 0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739, - 0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681, - 0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + NDArray expH('c', {bS, sL, 2*nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, + 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, + 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, + 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, + 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492, - -0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609, - -0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, + -0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f, + -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1398,12 +1400,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; @@ -1413,14 +1415,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060, - 0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491, - 0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, + 0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f, + 0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, - -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, - -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); @@ -1568,12 +1573,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674, - 0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028, - 0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, + 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, + 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1650,16 +1656,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843, - 0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139, - 0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 , - 0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 , - 0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, 2*nOut}, { + 0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, + 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, + 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, + 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, + 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502, - -0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425, - -0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f, + -0.103843f, -0.103843f, -0.103843f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f, + -0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1731,14 +1738,20 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) { std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023, - 0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., - 0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f, + 0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f, + 0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1799,25 +1812,26 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) { NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx = 0.003; - Wr = 0.006; - b = 0.5; - hI = 1.; - cI = 2.; - Wp = -0.05; + Wx = 0.003f; + Wr = 0.006f; + b = 0.5f; + hI = 1.f; + cI = 2.f; + Wp = -0.05f; std::initializer_list tArgs = {cellClip}; std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209, - 0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0., - 0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627, - 0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, - 0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32); + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f, + 0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f, + 0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, + 0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, nd4j::DataType::FLOAT32); - NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); @@ -1878,18 +1892,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); x.linspace(0.5, 0.5); - Wx({0,1, 0,0, 0,0}) = 0.003; - Wx({1,2, 0,0, 0,0}) = -0.003; - Wr({0,1, 0,0, 0,0}) = 0.006; - Wr({1,2, 0,0, 0,0}) = -0.006; - b({0,1, 0,0}) = 0.5; - b({1,2, 0,0}) = -0.5; + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; hI({0,1, 0,0, 0,0}) = 1; hI({1,2, 0,0, 0,0}) = -1; cI({0,1, 0,0, 0,0}) = 2; cI({1,2, 0,0, 0,0}) = -2; - Wp({0,1, 0,0}) = -0.05; - Wp({1,2, 0,0}) = 0.05; + Wp({0,1, 0,0}) = -0.05f; + Wp({1,2, 0,0}) = 0.05f; std::initializer_list tArgs = {cellClip}; std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; @@ -1905,10 +1919,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) { 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); - NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315, - 0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32); - NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702, - 0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32); + NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f, + 0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f, + 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32); nd4j::ops::lstmLayer op; auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index b2ccad86f..488adad0c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -148,8 +148,8 @@ TEST_F(DeclarableOpsTests15, Test_standarize_1) { } TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { - auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); - auto eps = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); nd4j::ops::standardize_bp op; auto result = op.execute({&x, &eps}, {}, {0}, {}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index 9f9c39156..a8377b429 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -1591,7 +1591,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -71.); + ASSERT_TRUE(result->e(0) == -71.f); delete results; @@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -69.); + ASSERT_TRUE(result->e(0) == -69.f); delete results; @@ -1630,8 +1630,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { auto weights = NDArrayFactory::create('c', {2,3,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + weights.assign(0.5f); + predictions.assign(0.5f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); @@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { auto *result = results->at(0); ASSERT_TRUE(result->isScalar()); - ASSERT_TRUE(result->e(0) == -24.); + ASSERT_TRUE(result->e(0) == -24.f); delete results; @@ -1655,8 +1655,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { auto weights = NDArrayFactory::create('c', {1,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); + weights.assign(0.5f); + predictions.assign(0.5f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); @@ -1680,10 +1680,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { auto weights = NDArrayFactory::create('c', {2,3,1}); labels.linspace(1); - weights.assign(0.5); - predictions.assign(0.5); - weights.p(0, 0.); - weights.p(1, 0.); + weights.assign(0.5f); + predictions.assign(0.5f); + weights.p(0, 0.f); + weights.p(1, 0.f); nd4j::ops::cosine_distance_loss op; auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 6d224b323..5322a0a6d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) { b.linspace(10.); x.assign(1.); - auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.,1.,1.,1.,1.,1.,1.}); + auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {}); @@ -2292,9 +2292,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2329,9 +2329,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2366,9 +2366,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; @@ -2421,9 +2421,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) { } else { for(uint i = 0; i < expU.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e(i)), nd4j::math::nd4j_abs(u->e(i)), 1e-5f); for(uint i = 0; i < expV.lengthOf(); ++i) - ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5); + ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e(i)), nd4j::math::nd4j_abs(v->e(i)), 1e-5f); } delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 23351f7af..220191011 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -4084,7 +4084,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) { TEST_F(DeclarableOpsTests7, Softsign_BP_1) { NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); -// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f}); NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); nd4j::ops::softsign ffOP; nd4j::ops::softsign_bp bpOp; diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index c89a989a9..e7f7f7e68 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -661,9 +661,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) { auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); // auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); - auto o = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); NDArray::prepareSpecialUse({&o}, {&x, &y}); @@ -685,9 +685,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) { TEST_F(JavaInteropTests, Test_Greater_2) { auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); - auto o = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); - auto exp = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); nd4j::ops::greater op; diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu index 71ad6929b..7740cd1ac 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu +++ b/libnd4j/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -1163,10 +1163,10 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32); NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32); - NDArray exp1('c', {3}, {4., 20., 36.}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {2,3}, {-10., -2., 6.,14., 22., 30.}, nd4j::DataType::FLOAT32); - NDArray exp3('c', {4}, {38., 41., 44., 47.}, nd4j::DataType::FLOAT32); - NDArray exp4('c', {4}, {114., 117., 120., 123.}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, nd4j::DataType::FLOAT32); + NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, nd4j::DataType::FLOAT32); + NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32); NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2}); @@ -1271,8 +1271,10 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE); NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray exp1('c', {3,2}, {-88., -124., 6., -2., 22., 14.}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {6,4}, {-36., -44., -52., -60.,-42., -52., -62., -72.,2., 0., -2., -4.,6., 4., 2., 0.,10., 8., 6., 4.,14., 12., 10., 8.}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, + -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, + nd4j::DataType::FLOAT32); NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE); @@ -1400,10 +1402,10 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE); - NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32); NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE); NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::FLOAT32); + NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32); x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1503,7 +1505,7 @@ TEST_F(NDArrayCudaBasicsTests, EqualityTest1) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { - NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32); + NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32); NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32); @@ -1511,11 +1513,11 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32); NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32); - NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32); - NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32); - NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32); - NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32); - NDArray exp5('c', {2}, {21.5,5}, nd4j::DataType::FLOAT32); + NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32); + NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32); + NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32); + NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32); x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1575,17 +1577,17 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE); - NDArray z1('c', {}, {100}, nd4j::DataType::BOOL); - NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL); - NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL); - NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL); - NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL); + NDArray z1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL); + NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); + NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL); + NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL); - NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL); - NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL); - NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL); - NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL); - NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL); + NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL); + NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL); + NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL); + NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL); + NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL); x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2}); ASSERT_TRUE(z1.equalsTo(&exp1)); @@ -1643,7 +1645,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { //////////////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { - NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32); + NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32); NDArray z1('c', {}, {100}, nd4j::DataType::INT64); NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64); @@ -1912,7 +1914,7 @@ TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) { double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; - NDArray a('c', {4,4}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, nd4j::DataType::FLOAT32); + NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, nd4j::DataType::FLOAT32); auto x = NDArrayFactory::create('c', {3, 2, 1}); auto y = NDArrayFactory::create('c', {1, 2}); auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); @@ -1928,7 +1930,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayCudaBasicsTests, assign_2) { - NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, nd4j::DataType::FLOAT32); + NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, nd4j::DataType::FLOAT32); NDArray y('c', {4}, nd4j::DataType::INT32); NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32); @@ -1945,30 +1947,30 @@ TEST_F(NDArrayCudaBasicsTests, subarray_1) NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32); Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX0[] = {1.000000, 13.000000}; + float buffExpX0[] = {1.f, 13.f}; Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; - float buffExpX1[] = {2.000000, 14.000000}; + float buffExpX1[] = {2.f, 14.f}; Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; - float buffExpX2[] = {1.000000, 13.000000}; + float buffExpX2[] = {1.f, 13.f}; Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; - float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; - float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; - float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000}; + float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY0[] = {1.000000, 2.000000}; + float buffExpY0[] = {1.f, 2.f}; Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; - float buffExpY1[] = {7.000000, 8.000000}; + float buffExpY1[] = {7.f, 8.f}; Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; - float buffExpY2[] = {1.000000, 2.000000}; + float buffExpY2[] = {1.f, 2.f}; Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; - float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; - float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; - float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000}; + float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; NDArray x0 = x(0, {1,2}); @@ -2121,7 +2123,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); //x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x->reshapei('c', {3, 4, 5}); x->permutei({0, 1, 2}); @@ -2138,7 +2140,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { auto x = NDArrayFactory::create('c', {1, 60}); x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); @@ -2153,7 +2155,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) { auto x = NDArrayFactory::create('c', {1, 60}); x.linspace(1); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); @@ -2170,7 +2172,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) { auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); // auto x = *xx; //x.linspace(1); -// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); +// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); // x.reshapei('c', {3, 4, 5}); // x.permutei({0, 1, 2}); @@ -2188,7 +2190,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) { //x.linspace(1); for (int l = 0; l < x.lengthOf(); l++) x.p(l, float(l + 1.f)); - auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); x.reshapei('c', {3, 4, 5}); x.permutei({0, 1, 2}); diff --git a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp index 0f3cab509..d0fb4bf37 100644 --- a/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NDArrayTests.cpp @@ -774,7 +774,7 @@ TEST_F(NDArrayTest, TestTile3) { TEST_F(NDArrayTest, TestTile4) { float xBuff[] = {1,2,3,4,5,6}; - float expBuff[] = {1.,2., 1.,2., 3.,4., 3.,4., 5.,6., 5.,6.}; + float expBuff[] = {1.f,2.f, 1.f,2.f, 3.f,4.f, 3.f,4.f, 5.f,6.f, 5.f,6.f}; auto x = NDArrayFactory::create(xBuff, 'c', {3,1,2}); auto exp = NDArrayFactory::create(expBuff, 'c', {3,2,2}); @@ -789,7 +789,7 @@ TEST_F(NDArrayTest, TestTile4) { TEST_F(NDArrayTest, TestTile5) { float xBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12}; - float expBuff[] = {1., 2., 3., 4., 1., 2., 3., 4., 5., 6., 7., 8., 5., 6., 7., 8., 9.,10., 11.,12., 9.,10., 11.,12.}; + float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f,12.f, 9.f,10.f, 11.f,12.f}; auto x = NDArrayFactory::create(xBuff, 'c', {3,2,2}); auto exp = NDArrayFactory::create(expBuff, 'c', {3,4,2}); @@ -847,7 +847,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul1) { auto y = NDArrayFactory::create('c', {3, 6}); Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0, 252.0, 273.0, 537.0, 594.0, 651.0, 843.0, 936.0, 1029.0}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -872,7 +872,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { auto y = NDArrayFactory::create('c', {3, 6}); Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; - float _expB[] = {231.0, 252.0, 273.0, 537.0, 594.0, 651.0, 843.0, 936.0, 1029.0}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -903,7 +903,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul3) { auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0, 1858.0, 2092.0, 2326.0, 5368.0, 5602.0, 5836.0, 6070.0, 4504.0, 5170.0, 5836.0, 6502.0, 15160.0, 15826.0, 16492.0, 17158.0}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -931,7 +931,7 @@ TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; - float _expB[] = {1624.0, 1858.0, 2092.0, 2326.0, 5368.0, 5602.0, 5836.0, 6070.0, 4504.0, 5170.0, 5836.0, 6502.0, 15160.0, 15826.0, 16492.0, 17158.0}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; NDArray exp(_expB, _expS); for (int e = 0; e < x.lengthOf(); e++) @@ -971,7 +971,7 @@ TEST_F(NDArrayTest, TestMmulHelper2) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{28.00, 64.00, 100.00, 136.00, 172.00}; + auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo(), nd4j::LaunchContext ::defaultContext(), true); //nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -1000,7 +1000,7 @@ TEST_F(NDArrayTest, TestMmulHelper3) { auto z = NDArrayFactory::create_('f', {5, 1}); - auto expBuffer = new float[5]{92.00, 104.00, 116.00, 128.00, 140.00}; + auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); //nd4j::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1); @@ -1035,7 +1035,7 @@ TEST_F(NDArrayTest, TestMmulHelper4) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0, 21.0, 35.0, 10.0, 28.0, 46.0, 13.0, 35.0, 57.0}; + auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, 46.0f, 13.0f, 35.0f, 57.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1065,7 +1065,7 @@ TEST_F(NDArrayTest, TestMmulHelper5) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{7.0, 14.0, 21.0, 12.0, 21.0, 30.0, 17.0, 28.0, 39.0}; + auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, 30.0f, 17.0f, 28.0f, 39.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1095,7 +1095,7 @@ TEST_F(NDArrayTest, TestMmulHelper6) { auto z = NDArrayFactory::create_('f', {3, 3}); - auto expBuffer = new float[9]{39.0, 54.0, 69.0, 9.0, 18.0, 27.0, 9.0, 12.0, 15.0}; + auto expBuffer = new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(x, y, z); @@ -1126,7 +1126,7 @@ TEST_F(NDArrayTest, TestMmulHelper7) { auto z = NDArrayFactory::create_('f', {1, 3}); - auto expBuffer = new float[9]{110.00, 260.00, 410.00}; + auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; auto exp = new NDArray(expBuffer, z->getShapeInfo()); MmulHelper::mmul(y, x, z); @@ -1171,7 +1171,59 @@ TEST_F(NDArrayTest, TestMmulHelper_ND_1) { TEST_F(NDArrayTest, TestMmulHelper_ND_2) { Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; - float _expB[] = {1.07250000e+04, 1.10500000e+04, 2.63500000e+04, 2.73000000e+04, 4.19750000e+04, 4.35500000e+04, 5.76000000e+04, 5.98000000e+04, 7.32250000e+04, 7.60500000e+04, 8.88500000e+04, 9.23000000e+04, 1.04475000e+05, 1.08550000e+05, 1.20100000e+05, 1.24800000e+05, 1.35725000e+05, 1.41050000e+05, 1.51350000e+05, 1.57300000e+05, 1.66975000e+05, 1.73550000e+05, 1.82600000e+05, 1.89800000e+05, 1.98225000e+05, 2.06050000e+05, 2.13850000e+05, 2.22300000e+05, 2.29475000e+05, 2.38550000e+05, 2.45100000e+05, 2.54800000e+05, 2.60725000e+05, 2.71050000e+05, 2.76350000e+05, 2.87300000e+05, 2.91975000e+05, 3.03550000e+05, 3.07600000e+05, 3.19800000e+05, 3.23225000e+05, 3.36050000e+05, 3.38850000e+05, 3.52300000e+05, 3.54475000e+05, 3.68550000e+05, 3.70100000e+05, 3.84800000e+05, 3.85725000e+05, 4.01050000e+05, 4.01350000e+05, 4.17300000e+05, 4.16975000e+05, 4.33550000e+05, 4.32600000e+05, 4.49800000e+05, 4.48225000e+05, 4.66050000e+05, 4.63850000e+05, 4.82300000e+05, 4.79475000e+05, 4.98550000e+05, 4.95100000e+05, 5.14800000e+05, 5.10725000e+05, 5.31050000e+05, 5.26350000e+05, 5.47300000e+05, 5.41975000e+05, 5.63550000e+05, 5.57600000e+05, 5.79800000e+05, 5.73225000e+05, 5.96050000e+05, 5.88850000e+05, 6.12300000e+05, 6.04475000e+05, 6.28550000e+05, 6.20100000e+05, 6.44800000e+05, 6.35725000e+05, 6.61050000e+05, 6.51350000e+05, 6.77300000e+05, 6.66975000e+05, 6.93550000e+05, 6.82600000e+05, 7.09800000e+05, 6.98225000e+05, 7.26050000e+05, 7.13850000e+05, 7.42300000e+05, 7.29475000e+05, 7.58550000e+05, 7.45100000e+05, 7.74800000e+05, 7.60725000e+05, 7.91050000e+05, 7.76350000e+05, 8.07300000e+05, 7.91975000e+05, 8.23550000e+05, 8.07600000e+05, 8.39800000e+05, 8.23225000e+05, 8.56050000e+05, 8.38850000e+05, 8.72300000e+05, 8.54475000e+05, 8.88550000e+05, 8.70100000e+05, 9.04800000e+05, 8.85725000e+05, 9.21050000e+05, 9.01350000e+05, 9.37300000e+05, 9.16975000e+05, 9.53550000e+05, 9.32600000e+05, 9.69800000e+05, 9.48225000e+05, 9.86050000e+05, 9.63850000e+05, 1.00230000e+06, 9.79475000e+05, 1.01855000e+06, 9.95100000e+05, 1.03480000e+06, 1.01072500e+06, 1.05105000e+06, 1.02635000e+06, 1.06730000e+06, 1.04197500e+06, 1.08355000e+06, 1.05760000e+06, 1.09980000e+06, 1.07322500e+06, 1.11605000e+06, 1.08885000e+06, 1.13230000e+06, 1.10447500e+06, 1.14855000e+06, 1.12010000e+06, 1.16480000e+06, 1.13572500e+06, 1.18105000e+06, 1.15135000e+06, 1.19730000e+06, 1.16697500e+06, 1.21355000e+06, 3.54260000e+06, 3.58980000e+06, 3.58947500e+06, 3.63730000e+06, 3.63635000e+06, 3.68480000e+06, 3.68322500e+06, 3.73230000e+06, 3.73010000e+06, 3.77980000e+06, 3.77697500e+06, 3.82730000e+06, 3.82385000e+06, 3.87480000e+06, 3.87072500e+06, 3.92230000e+06, 3.91760000e+06, 3.96980000e+06, 3.96447500e+06, 4.01730000e+06, 4.01135000e+06, 4.06480000e+06, 4.05822500e+06, 4.11230000e+06, 4.10510000e+06, 4.15980000e+06, 4.15197500e+06, 4.20730000e+06, 4.19885000e+06, 4.25480000e+06, 4.24572500e+06, 4.30230000e+06, 4.29260000e+06, 4.34980000e+06, 4.33947500e+06, 4.39730000e+06, 4.38635000e+06, 4.44480000e+06, 4.43322500e+06, 4.49230000e+06, 4.48010000e+06, 4.53980000e+06, 4.52697500e+06, 4.58730000e+06, 4.57385000e+06, 4.63480000e+06, 4.62072500e+06, 4.68230000e+06, 4.66760000e+06, 4.72980000e+06, 4.71447500e+06, 4.77730000e+06, 4.76135000e+06, 4.82480000e+06, 4.80822500e+06, 4.87230000e+06, 4.85510000e+06, 4.91980000e+06, 4.90197500e+06, 4.96730000e+06, 4.94885000e+06, 5.01480000e+06, 4.99572500e+06, 5.06230000e+06, 5.04260000e+06, 5.10980000e+06, 5.08947500e+06, 5.15730000e+06, 5.13635000e+06, 5.20480000e+06, 5.18322500e+06, 5.25230000e+06, 5.23010000e+06, 5.29980000e+06, 5.27697500e+06, 5.34730000e+06, 5.32385000e+06, 5.39480000e+06, 5.37072500e+06, 5.44230000e+06, 5.41760000e+06, 5.48980000e+06, 5.46447500e+06, 5.53730000e+06, 5.51135000e+06, 5.58480000e+06, 5.55822500e+06, 5.63230000e+06, 5.60510000e+06, 5.67980000e+06, 5.65197500e+06, 5.72730000e+06, 5.69885000e+06, 5.77480000e+06, 5.74572500e+06, 5.82230000e+06, 5.79260000e+06, 5.86980000e+06, 5.83947500e+06, 5.91730000e+06, 5.88635000e+06, 5.96480000e+06, 5.93322500e+06, 6.01230000e+06, 5.98010000e+06, 6.05980000e+06, 6.02697500e+06, 6.10730000e+06, 6.07385000e+06, 6.15480000e+06, 6.12072500e+06, 6.20230000e+06, 6.16760000e+06, 6.24980000e+06, 6.21447500e+06, 6.29730000e+06, 6.26135000e+06, 6.34480000e+06, 6.30822500e+06, 6.39230000e+06, 6.35510000e+06, 6.43980000e+06, 6.40197500e+06, 6.48730000e+06, 6.44885000e+06, 6.53480000e+06, 6.49572500e+06, 6.58230000e+06, 6.54260000e+06, 6.62980000e+06, 6.58947500e+06, 6.67730000e+06, 6.63635000e+06, 6.72480000e+06, 6.68322500e+06, 6.77230000e+06, 6.73010000e+06, 6.81980000e+06, 6.77697500e+06, 6.86730000e+06, 6.82385000e+06, 6.91480000e+06, 6.87072500e+06, 6.96230000e+06, 6.91760000e+06, 7.00980000e+06, 6.96447500e+06, 7.05730000e+06, 7.01135000e+06, 7.10480000e+06, 1.17619750e+07, 1.18560500e+07, 1.18401000e+07, 1.19348000e+07, 1.19182250e+07, 1.20135500e+07, 1.19963500e+07, 1.20923000e+07, 1.20744750e+07, 1.21710500e+07, 1.21526000e+07, 1.22498000e+07, 1.22307250e+07, 1.23285500e+07, 1.23088500e+07, 1.24073000e+07, 1.23869750e+07, 1.24860500e+07, 1.24651000e+07, 1.25648000e+07, 1.25432250e+07, 1.26435500e+07, 1.26213500e+07, 1.27223000e+07, 1.26994750e+07, 1.28010500e+07, 1.27776000e+07, 1.28798000e+07, 1.28557250e+07, 1.29585500e+07, 1.29338500e+07, 1.30373000e+07, 1.30119750e+07, 1.31160500e+07, 1.30901000e+07, 1.31948000e+07, 1.31682250e+07, 1.32735500e+07, 1.32463500e+07, 1.33523000e+07, 1.33244750e+07, 1.34310500e+07, 1.34026000e+07, 1.35098000e+07, 1.34807250e+07, 1.35885500e+07, 1.35588500e+07, 1.36673000e+07, 1.36369750e+07, 1.37460500e+07, 1.37151000e+07, 1.38248000e+07, 1.37932250e+07, 1.39035500e+07, 1.38713500e+07, 1.39823000e+07, 1.39494750e+07, 1.40610500e+07, 1.40276000e+07, 1.41398000e+07, 1.41057250e+07, 1.42185500e+07, 1.41838500e+07, 1.42973000e+07, 1.42619750e+07, 1.43760500e+07, 1.43401000e+07, 1.44548000e+07, 1.44182250e+07, 1.45335500e+07, 1.44963500e+07, 1.46123000e+07, 1.45744750e+07, 1.46910500e+07, 1.46526000e+07, 1.47698000e+07, 1.47307250e+07, 1.48485500e+07, 1.48088500e+07, 1.49273000e+07, 1.48869750e+07, 1.50060500e+07, 1.49651000e+07, 1.50848000e+07, 1.50432250e+07, 1.51635500e+07, 1.51213500e+07, 1.52423000e+07, 1.51994750e+07, 1.53210500e+07, 1.52776000e+07, 1.53998000e+07, 1.53557250e+07, 1.54785500e+07, 1.54338500e+07, 1.55573000e+07, 1.55119750e+07, 1.56360500e+07, 1.55901000e+07, 1.57148000e+07, 1.56682250e+07, 1.57935500e+07, 1.57463500e+07, 1.58723000e+07, 1.58244750e+07, 1.59510500e+07, 1.59026000e+07, 1.60298000e+07, 1.59807250e+07, 1.61085500e+07, 1.60588500e+07, 1.61873000e+07, 1.61369750e+07, 1.62660500e+07, 1.62151000e+07, 1.63448000e+07, 1.62932250e+07, 1.64235500e+07, 1.63713500e+07, 1.65023000e+07, 1.64494750e+07, 1.65810500e+07, 1.65276000e+07, 1.66598000e+07, 1.66057250e+07, 1.67385500e+07, 1.66838500e+07, 1.68173000e+07, 1.67619750e+07, 1.68960500e+07, 1.68401000e+07, 1.69748000e+07, 1.69182250e+07, 1.70535500e+07, 1.69963500e+07, 1.71323000e+07, 1.70744750e+07, 1.72110500e+07, 1.71526000e+07, 1.72898000e+07, 1.72307250e+07, 1.73685500e+07, 1.73088500e+07, 1.74473000e+07, 1.73869750e+07, 1.75260500e+07, 1.74651000e+07, 1.76048000e+07, 1.75432250e+07, 1.76835500e+07, 2.46688500e+07, 2.48098000e+07, 2.47782250e+07, 2.49198000e+07, 2.48876000e+07, 2.50298000e+07, 2.49969750e+07, 2.51398000e+07, 2.51063500e+07, 2.52498000e+07, 2.52157250e+07, 2.53598000e+07, 2.53251000e+07, 2.54698000e+07, 2.54344750e+07, 2.55798000e+07, 2.55438500e+07, 2.56898000e+07, 2.56532250e+07, 2.57998000e+07, 2.57626000e+07, 2.59098000e+07, 2.58719750e+07, 2.60198000e+07, 2.59813500e+07, 2.61298000e+07, 2.60907250e+07, 2.62398000e+07, 2.62001000e+07, 2.63498000e+07, 2.63094750e+07, 2.64598000e+07, 2.64188500e+07, 2.65698000e+07, 2.65282250e+07, 2.66798000e+07, 2.66376000e+07, 2.67898000e+07, 2.67469750e+07, 2.68998000e+07, 2.68563500e+07, 2.70098000e+07, 2.69657250e+07, 2.71198000e+07, 2.70751000e+07, 2.72298000e+07, 2.71844750e+07, 2.73398000e+07, 2.72938500e+07, 2.74498000e+07, 2.74032250e+07, 2.75598000e+07, 2.75126000e+07, 2.76698000e+07, 2.76219750e+07, 2.77798000e+07, 2.77313500e+07, 2.78898000e+07, 2.78407250e+07, 2.79998000e+07, 2.79501000e+07, 2.81098000e+07, 2.80594750e+07, 2.82198000e+07, 2.81688500e+07, 2.83298000e+07, 2.82782250e+07, 2.84398000e+07, 2.83876000e+07, 2.85498000e+07, 2.84969750e+07, 2.86598000e+07, 2.86063500e+07, 2.87698000e+07, 2.87157250e+07, 2.88798000e+07, 2.88251000e+07, 2.89898000e+07, 2.89344750e+07, 2.90998000e+07, 2.90438500e+07, 2.92098000e+07, 2.91532250e+07, 2.93198000e+07, 2.92626000e+07, 2.94298000e+07, 2.93719750e+07, 2.95398000e+07, 2.94813500e+07, 2.96498000e+07, 2.95907250e+07, 2.97598000e+07, 2.97001000e+07, 2.98698000e+07, 2.98094750e+07, 2.99798000e+07, 2.99188500e+07, 3.00898000e+07, 3.00282250e+07, 3.01998000e+07, 3.01376000e+07, 3.03098000e+07, 3.02469750e+07, 3.04198000e+07, 3.03563500e+07, 3.05298000e+07, 3.04657250e+07, 3.06398000e+07, 3.05751000e+07, 3.07498000e+07, 3.06844750e+07, 3.08598000e+07, 3.07938500e+07, 3.09698000e+07, 3.09032250e+07, 3.10798000e+07, 3.10126000e+07, 3.11898000e+07, 3.11219750e+07, 3.12998000e+07, 3.12313500e+07, 3.14098000e+07, 3.13407250e+07, 3.15198000e+07, 3.14501000e+07, 3.16298000e+07, 3.15594750e+07, 3.17398000e+07, 3.16688500e+07, 3.18498000e+07, 3.17782250e+07, 3.19598000e+07, 3.18876000e+07, 3.20698000e+07, 3.19969750e+07, 3.21798000e+07, 3.21063500e+07, 3.22898000e+07, 3.22157250e+07, 3.23998000e+07, 3.23251000e+07, 3.25098000e+07, 3.24344750e+07, 3.26198000e+07, 3.25438500e+07, 3.27298000e+07, 3.26532250e+07, 3.28398000e+07, 3.27626000e+07, 3.29498000e+07}; + float _expB[] = { + 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, 4.19750000e+04f, 4.35500000e+04f, + 5.76000000e+04f, 5.98000000e+04f, 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, + 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, 1.35725000e+05f, 1.41050000e+05f, + 1.51350000e+05f, 1.57300000e+05f, 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, + 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, 2.29475000e+05f, 2.38550000e+05f, + 2.45100000e+05f, 2.54800000e+05f, 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, + 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, 3.23225000e+05f, 3.36050000e+05f, + 3.38850000e+05f, 3.52300000e+05f, 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, + 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, 4.16975000e+05f, 4.33550000e+05f, + 4.32600000e+05f, 4.49800000e+05f, 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, + 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, 5.10725000e+05f, 5.31050000e+05f, + 5.26350000e+05f, 5.47300000e+05f, 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, + 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, 6.04475000e+05f, 6.28550000e+05f, + 6.20100000e+05f, 6.44800000e+05f, 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, + 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, 6.98225000e+05f, 7.26050000e+05f, + 7.13850000e+05f, 7.42300000e+05f, 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, + 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, 7.91975000e+05f, 8.23550000e+05f, + 8.07600000e+05f, 8.39800000e+05f, 8.23225000e+05f, 8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, + 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, 8.85725000e+05f, 9.21050000e+05f, + 9.01350000e+05f, 9.37300000e+05f, 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, + 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, 9.79475000e+05f, 1.01855000e+06f, + 9.95100000e+05f, 1.03480000e+06f, 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, + 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, 1.07322500e+06f, 1.11605000e+06f, + 1.08885000e+06f, 1.13230000e+06f, 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, + 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, 1.16697500e+06f, 1.21355000e+06f, + 3.54260000e+06f, 3.58980000e+06f, 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, + 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, 3.77697500e+06f, 3.82730000e+06f, + 3.82385000e+06f, 3.87480000e+06f, 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, + 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, 4.05822500e+06f, 4.11230000e+06f, + 4.10510000e+06f, 4.15980000e+06f, 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, + 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, 4.33947500e+06f, 4.39730000e+06f, + 4.38635000e+06f, 4.44480000e+06f, 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, + 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, 4.62072500e+06f, 4.68230000e+06f, + 4.66760000e+06f, 4.72980000e+06f, 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, + 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, 4.90197500e+06f, 4.96730000e+06f, + 4.94885000e+06f, 5.01480000e+06f, 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, + 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, 5.18322500e+06f, 5.25230000e+06f, + 5.23010000e+06f, 5.29980000e+06f, 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, + 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, 5.46447500e+06f, 5.53730000e+06f, + 5.51135000e+06f, 5.58480000e+06f, 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, + 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, 5.74572500e+06f, 5.82230000e+06f, + 5.79260000e+06f, 5.86980000e+06f, 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, + 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, 6.02697500e+06f, 6.10730000e+06f, + 6.07385000e+06f, 6.15480000e+06f, 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, + 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, 6.30822500e+06f, 6.39230000e+06f, + 6.35510000e+06f, 6.43980000e+06f, 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, + 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, 6.58947500e+06f, 6.67730000e+06f, + 6.63635000e+06f, 6.72480000e+06f, 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, + 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, 6.87072500e+06f, 6.96230000e+06f, + 6.91760000e+06f, 7.00980000e+06f, 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, + 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, 1.19182250e+07f, 1.20135500e+07f, + 1.19963500e+07f, 1.20923000e+07f, 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 1.31948000e+07f, 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, 2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, 3.11219750e+07f, 3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; auto a = NDArrayFactory::create('c', {2, 72, 25}); for (int e = 0; e < a.lengthOf(); e++) @@ -1626,7 +1678,7 @@ TEST_F(NDArrayTest, applyReduce3Dot) { TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214, 1.414214, 5.385165, 5.385165}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; @@ -1649,7 +1701,7 @@ TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { float xBuff[] = {1, 2, 3, 4, 5, 6}; float yBuff[] = {2, 2, 2, 2, 2, 2}; - float expBuff[] = {1.414214, 1.414214, 5.385165, 5.385165}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; @@ -1670,7 +1722,7 @@ TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { TEST_F(NDArrayTest, TestVarianceAlongDimension1) { float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.816497, 0.816497}; + float expBuff[] = {0.816497f, 0.816497f}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; @@ -1688,7 +1740,7 @@ TEST_F(NDArrayTest, TestVarianceAlongDimension1) { ////////////////////////////////////////////////////////////////////// TEST_F(NDArrayTest, TestVarianceAlongDimension2) { float xBuff[] = {1, 2, 3, 4, 5, 6}; - float expBuff[] = {0.666667, 0.666667}; + float expBuff[] = {0.666667f, 0.666667f}; Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java index be6eb3730..b6a96699c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.image; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.NoArgsConstructor; +import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; @@ -43,20 +44,25 @@ import java.util.Map; @NoArgsConstructor public class ResizeBilinear extends DynamicCustomOp { protected boolean alignCorners = false; + protected boolean halfPixelCenters = false; protected Integer height = null; protected Integer width = null; - public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){ + public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, + boolean alignCorners, boolean halfPixelCenters){ super(sd, input); this.alignCorners = alignCorners; this.height = height; this.width = width; + this.halfPixelCenters = halfPixelCenters; addArgs(); } - public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){ + public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, + boolean alignCorners, boolean halfPixelCenters) { super(new INDArray[]{x}, new INDArray[]{z}); this.alignCorners = alignCorners; + this.halfPixelCenters = halfPixelCenters; this.height = height; this.width = width; addArgs(); @@ -76,7 +82,12 @@ public class ResizeBilinear extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); - this.alignCorners = attributesForNode.get("align_corners").getB(); + val attrC = attributesForNode.get("align_corners"); + val attrH = attributesForNode.get("half_pixel_centers"); + + this.alignCorners = attrC != null ? attrC.getB() : false; + this.halfPixelCenters = attrH != null ? attrH.getB() : false; + addArgs(); } @@ -87,8 +98,7 @@ public class ResizeBilinear extends DynamicCustomOp { iArguments.add(Long.valueOf(height)); iArguments.add(Long.valueOf(width)); } - iArguments.add(alignCorners ? 1L : 0L); - + addBArgument(alignCorners, halfPixelCenters); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 45a20bfbc..94c5601c1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -4584,6 +4584,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * returns reference on array element with given index */ + /** * returns array element with given index * i - element index in array @@ -5171,6 +5172,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + //////////////////////////////////////////////////////////////////////// @@ -5179,6 +5182,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + // #ifndef __JAVACPP_HACK__ // #endif diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e9a36d49f..0ba5d1293 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -4587,6 +4587,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); * returns reference on array element with given index */ + /** * returns array element with given index * i - element index in array @@ -5174,6 +5175,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + //////////////////////////////////////////////////////////////////////// @@ -5182,6 +5185,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + // #ifndef __JAVACPP_HACK__ // #endif @@ -18280,7 +18285,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x). - * Currently the case n = 0 is not supported. * * Input arrays: * 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer) @@ -18309,6 +18313,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This op calculates digamma function psi(x) = derivative of log(Gamma(x)) + * + * Input arrays: + * 0: x - abscissa points where to evaluate the digamma function, type float + * + * Output array: + * 0: values of digamma function at corresponding x, type float + * + */ +// #if NOT_EXCLUDED(OP_digamma) + @Namespace("nd4j::ops") public static class digamma extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public digamma(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public digamma(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public digamma position(long position) { + return (digamma)super.position(position); + } + + public digamma() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value. * Input arrays: @@ -18398,9 +18430,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image hue by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing delta * * T arguments: - * 0 - delta value + * 0 - optional argument, delta value * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels @@ -18427,9 +18460,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image saturation by delta * Input arrays: * 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation factor * * T arguments: - * 0 - saturation factor + * 0 - optional argument, saturation factor * * Int arguments: * 0 - optional argument, corresponds to dimension with 3 channels @@ -18456,9 +18490,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean ) * Input arrays: * 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels. + * 1 - optional argument, input scalar-array containing saturation contrast factor * * T arguments: - * 0 - contrast factor + * 0 - optional argument, contrast factor * */ // #if NOT_EXCLUDED(OP_adjust_contrast) @@ -21053,7 +21088,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #endif /** - * compare_and_bitpack - compare with greater and pack result with uint8 + * compare_and_bitpack - compare with greater and pack result with uint8 * * input params: * 0 - NDArray (input) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index ec65d71df..bc9f03e2f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -117,9 +117,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 "fake_quant/min_max_args_per_channel.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403 - "resize_bilinear/int32.*", - // Suggesting TF 1.15 bug "non_max_suppression_v2/float16.*", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index fbb1ddb85..742ffae66 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -972,7 +972,7 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray x = Nd4j.rand(1, 2,3,4); INDArray z = Nd4j.createUninitialized(x.shape()); boolean align = false; - val op = new ResizeBilinear(x, z, 10, 10, align); + val op = new ResizeBilinear(x, z, 10, 10, align, false); Nd4j.exec(op); } @@ -1174,6 +1174,7 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(expected, x); } + @Ignore("AS failed 2019/12/04") @Test public void testPolygamma() { INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3); From b32dd1bf92ca2aa5160cb9f64f351510d6baaa36 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 6 Dec 2019 18:58:37 +0300 Subject: [PATCH 08/18] [WIP] resize_bicubic types (#116) * resize_bicubic: allow more dtypes Signed-off-by: raver119 * resize_bicubic: allow less dtypes Signed-off-by: raver119 * Refactored resize_bicubic op to full conform with TF1.5 and tests. * Corrected test to proper data type output. Signed-off-by: shugeo * Corrected double input test to float constant outputs. Signed-off-by: shugeo * Finished with correction of tests for bicubic interpolated resizes expected. Signed-off-by: shugeo * Fixed adjust_contrast ops to allow non-RGB inputs. Signed-off-by: shugeo * Refactored adjust_contrast_v2 to conform with TF one. Signed-off-by: shugeo * AdjustContrast tests activated * two typos fixed Signed-off-by: raver119 --- .../generic/broadcastable/realdiv.cpp | 2 +- .../broadcastable/squared_subtract.cpp | 2 +- .../generic/parity_ops/adjust_contrast.cpp | 26 +- .../generic/parity_ops/resize_bicubic.cpp | 8 +- .../declarable/helpers/cpu/image_resize.cpp | 11 +- .../declarable/helpers/cuda/image_resize.cu | 4 +- .../layers_tests/DeclarableOpsTests11.cpp | 883 ++++++++---------- .../layers_tests/DeclarableOpsTests15.cpp | 634 +++++++++++++ .../TFGraphs/TFGraphTestAllSameDiff.java | 3 - 9 files changed, 1045 insertions(+), 528 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp index 4260f6ffa..7b4e374d5 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/realdiv.cpp @@ -76,7 +76,7 @@ namespace nd4j { // Y gradient //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); - gradY->assign(epsNext * -(*x) / ((*y) * (*y))); + gradY->assign((*epsNext) * -(*x) / ((*y) * (*y))); } else if (y->isScalar()) { // scalar case diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp index 655a26429..280a09857 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/squared_subtract.cpp @@ -89,7 +89,7 @@ namespace nd4j { gradY->assign(tmpX); //epsNext->applyPairwiseLambda(x, lambdaS, gradX); - gradX->assign(epsNext * ts * ((*x) - (*y))); + gradX->assign((*epsNext) * ts * ((*x) - (*y))); } else { // broadcast case diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 1aa0c5249..27b6a4302 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -39,7 +39,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) { REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required"); REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); +// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); NDArray* factor = nullptr; @@ -84,10 +84,15 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { return Status::OK(); REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf()); - REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); +// REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1)); REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required"); NDArray* factor = nullptr; + auto size = input->sizeAt(-2) * input->sizeAt(-3); + auto channels = input->sizeAt(-1); + auto batch = input->lengthOf() / (size * channels); + auto input3D = input->reshape(input->ordering(), {batch, size, channels}); + auto output3D = input->reshape(input->ordering(), {batch, size, channels}); if(block.width() > 1) factor = INPUT_VARIABLE(1); @@ -96,20 +101,17 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) { factor->p(0, T_ARG(0)); } - // compute mean before - std::vector axes(input->rankOf() - 1); - for (auto i = 0; i < axes.size(); ++i) - axes[i] = i; + std::vector axes({1}); // dim 1 of pseudoresult - // mean as reduction for last dimension set - auto mean = input->reduceAlongDims(reduce::Mean, axes); +// mean as reduction for last dimension set over size (dim 1) of result3D + auto mean = input3D.reduceAlongDims(reduce::Mean, axes); // result as (x - mean) * factor + mean - auto temp = input->ulike(); - input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp); + auto temp = input3D.ulike(); + input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr); temp.applyScalarArr(scalar::Multiply, factor); - temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output); - + temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D); + output->assign(output3D); if(block.width() == 1) delete factor; diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp index 99053561c..da98c1702 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/resize_bicubic.cpp @@ -96,16 +96,16 @@ namespace nd4j { outputShape[2] = height; outputShape[3] = in[3]; } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in)); shapeList->push_back(CONSTANT(outputShape)); return shapeList; } DECLARE_TYPES(resize_bicubic) { getOpDescriptor() - ->setAllowedInputTypes(0, {ALL_FLOATS}) - ->setAllowedInputTypes(1, {DataType::INT32}) - ->setAllowedOutputTypes({ALL_FLOATS}); + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, DataType::INT32) + ->setAllowedOutputTypes({DataType::FLOAT32}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index 16ddd17da..87bc9c9f2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -352,14 +352,12 @@ namespace helpers { int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, - (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); + BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES); } int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) { - BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, - (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES); } @@ -696,7 +694,7 @@ namespace helpers { const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; const T* inputPtr = image->getDataBuffer()->primaryAsT(); - T* pOutputY = output->dataBuffer()->primaryAsT(); //_data.data(); + float* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); auto func = PRAGMA_THREADS_FOR { @@ -881,8 +879,7 @@ namespace helpers { } int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, bool const alignCorners, bool const halfPixelAlign, NDArray* output) { - BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, - image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); + BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context, image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES); } // ------------------------------------------------------------------------------------------------------------------ // int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height, diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu index 4f025d851..ab3a96801 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_resize.cu @@ -689,7 +689,7 @@ namespace helpers { } template - static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, T* outputPtr) { + static __global__ void bicubicInterpolateWithCachingKernel(float const* cachedTable, float* cachedValue, T const* inputPtr, ImageResizerState* pResizerState, WeightsAndIndices* xWais, bool halfPixelCenters, Nd4jLong inBatchWidth, Nd4jLong inRowWidth, float* outputPtr) { // auto numChannels = pResizerState->channels; for (Nd4jLong b = blockIdx.x; b < pResizerState->batchSize; b += gridDim.x) { auto pInput = inputPtr + b * inBatchWidth; @@ -877,7 +877,7 @@ namespace helpers { throw cuda_exception::build("helpers::bicubicInterpolateWithCaching: computeXWeigtsAndInidces finished with error", err); } const T* pInput = image->getDataBuffer()->specialAsT(); - T* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); + float* pOutput = output->dataBuffer()->specialAsT(); //_data.data(); bicubicInterpolateWithCachingKernel<<<128, 1, 512, *stream>>>(coeffsTable, cachedValue, pInput, resizerStateD, xWais, halfPixelCenters, inBatchWidth, inRowWidth, pOutput); err = cudaStreamSynchronize(*stream); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index ac56c496f..647f37271 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -479,166 +479,166 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { - NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { - 1, 2.1, 3.15, 4.2, 5.15, 6.1, 7, - 8, 9.1, 10., 11, 12.9, 13.1, 14, - 15, 16., 17., 18, 19, 20., 21, - 22, 23., 24., 25, 26, 27, 28, - 30, 31, 32, 33, 34., 35, 36, - 37, 38, 39, 40, 41., 42, 43, - 44, 45, 46, 47, 48., 49, 50 + NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { + 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, + 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, + 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, + 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f }); - NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { - 1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 , - 2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 , - 3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 , - 5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 , - 6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 , - 2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 , - 3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 , - 5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 , - 6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 , - 7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 , - 3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 , - 5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 , - 6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 , - 8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 , - 9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 , - 5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 , - 6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 , - 8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 , - 10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 , - 10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 , - 7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 , - 8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 , - 9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 , - 12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 , - 12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 , - 9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 , - 10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 , - 12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 , - 14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 , - 15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 , - 10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 , - 12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 , - 13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 , - 15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 , - 16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 , - 12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 , - 13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 , - 14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 , - 16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 , - 17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 , - 13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 , - 15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 , - 16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 , - 18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 , - 19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 , - 15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 , - 17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 , - 18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 , - 20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 , - 21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 , - 17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 , - 18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 , - 20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 , - 21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 , - 23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 , - 18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 , - 20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 , - 21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 , - 22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 , - 24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 , - 20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 , - 21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 , - 22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 , - 24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 , - 25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 , - 22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 , - 23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 , - 25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 , - 26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 , - 28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 , - 24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 , - 25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 , - 27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 , - 28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 , - 30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 , - 26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 , - 27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 , - 28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 , - 30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 , - 31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 , - 27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 , - 28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 , - 30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 , - 31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 , - 33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 , - 29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 , - 31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 , - 32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 , - 33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 , - 35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 , - 31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 , - 33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 , - 34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 , - 36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 , - 37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 , - 33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 , - 34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 , - 36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 , - 37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 , - 38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 , - 34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 , - 35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 , - 37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 , - 38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 , - 40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 , - 36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 , - 37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 , - 38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 , - 40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 , - 41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 , - 38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 , - 39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 , - 41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 , - 42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 , - 43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 , - 40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 , - 41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 , - 42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 , - 44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 , - 45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 , - 41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 , - 43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 , - 44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 , - 46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 , - 47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 , - 43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 , - 44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 , - 45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 , - 47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 , - 48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 , - 44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 , - 45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 , - 47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 , - 48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 , - 49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 , - 44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 , - 46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 , - 47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 , - 49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 , - 50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 , - 44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 , - 46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 , - 47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 , - 48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 , - 50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 , - 44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 , - 45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 , - 46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 , - 48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 , - 49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057}); + NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { + 1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f, + 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f, + 3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f, + 5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, + 6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, + 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f, + 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f, + 5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f, + 6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.6286845f, 3.830573f, 4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f, + 5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f, + 6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f, + 8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f, + 9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f, + 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f, + 8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f, + 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f, + 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, + 12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, + 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f, + 10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f, + 14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f, + 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, + 12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f, + 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f, + 16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, + 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, + 16.8099f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, + 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, + 18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f, + 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f, + 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, + 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, + 21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, + 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, + 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, + 21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f, + 23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f, + 20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f, + 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, + 24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, + 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, + 21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f, + 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f, + 25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.54125f, 22.738552f, 22.973991f, 23.266647f, 23.57634f, + 23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, + 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, + 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f, + 25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f, + 27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f, + 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, + 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f, + 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f, + 28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f, + 30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, + 31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f, + 30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f, + 31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f, + 33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, + 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f, + 31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f, + 32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f, + 33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f, + 35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, + 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f, + 33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f, + 34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f, + 36.081642f, 36.28334f, 36.46334f, 36.69792f, 37.021297f, 37.294453f, + 37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f, + 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, + 37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f, + 38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, + 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f, + 35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f, + 37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f, + 38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f, + 40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, + 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f, + 37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f, + 38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f, + 40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f, + 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f, + 39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f, + 41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f, + 42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, + 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f, + 41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f, + 42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f, + 44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f, + 45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f, + 43.209923f, 43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f, + 44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f, + 47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, + 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f, + 45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, + 47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f, + 45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f, + 48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f, + 49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f, + 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, + 47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f, + 49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f, + 50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f, + 46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f, + 47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f, + 48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f, + 50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, + 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f, + 45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f, + 48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f, + 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); auto size = NDArrayFactory::create({30, 30}); nd4j::ops::resize_bicubic op; @@ -656,64 +656,63 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { - 1. , 2. ,3. ,2.21875, 3.21875, 4.21875, 4. , 5. , 6. ,5.5, - 6.5, 7.5, 7., 8., 9. ,8.78125, 9.78125, 10.78125, 10., 11. , - 12., 10.28125 , 11.28125 ,12.28125, 5.875, 6.875, 7.875, 7.09375, 8.09375 ,9.09375, - 8.875, 9.875, 10.875, 10.375, 11.375, 12.375 ,11.875 ,12.875 , 13.875, 13.65625, - 14.65625, 15.65625, 14.875 ,15.875 ,16.875 , 15.15625, 16.15625, 17.15625, 13., 14., - 15. ,14.21875, 15.21875, 16.21875, 16., 17., 18. ,17.5 ,18.5 , 19.5, - 19., 20., 21., 20.78125 ,21.78125 ,22.78125, 22., 23. , 24. , 22.28125, - 23.28125 ,24.28125 ,19. , 20., 21., 20.21875, 21.21875, 22.21875 ,22. ,23., - 24. , 23.5, 24.5, 25.5, 25. ,26. ,27., 26.78125 , 27.78125, 28.78125, - 28., 29. ,30. ,28.28125, 29.28125, 30.28125, 25., 26., 27. ,26.21875, - 27.21875, 28.21875, 28., 29., 30., 29.5 ,30.5 ,31.5 , 31., 32., - 33., 32.78125, 33.78125 ,34.78125 ,34., 35., 36., 34.28125, 35.28125, 36.28125, - 31. ,32., 33. , 32.21875, 33.21875, 34.21875, 34. ,35. ,36., 35.5, - 36.5 , 37.5 , 37., 38. ,39. ,38.78125, 39.78125, 40.78125, 40., 41., - 42. ,40.28125 ,41.28125, 42.28125, 37. , 38., 39., 38.21875 ,39.21875 ,40.21875, - 40. , 41. , 42. , 41.5, 42.5, 43.5 ,43., 44., 45., 44.78125, - 45.78125, 46.78125 ,46. ,47. , 48. , 46.28125 , 47.28125, 48.28125, 44.125 ,45.125, - 46.125, 45.34375, 46.34375, 47.34375, 47.125, 48.125 ,49.125 ,48.625, 49.625 , 50.625, - 50.125 , 51.125, 52.125 ,51.90625 ,52.90625, 53.90625, 53.125, 54.125, 55.125, 53.40625, - 54.40625 ,55.40625, 49. ,50. , 51. ,50.21875, 51.21875 ,52.21875 ,52. ,53., - 54. ,53.5 , 54.5, 55.5 ,55. ,56. ,57. ,56.78125 ,57.78125, 58.78125, - 58. ,59. ,60. ,58.28125 ,59.28125 ,60.28125, 50.125, 51.125 ,52.125 ,51.34375, - 52.34375 ,53.34375 ,53.125, 54.125, 55.125 ,54.625 ,55.625 ,56.625 ,56.125 ,57.125, - 58.125, 57.90625 ,58.90625 ,59.90625 ,59.125 ,60.125 ,61.125, 59.40625, 60.40625 ,61.40625, - 61. ,62. ,63. ,62.21875, 63.21875, 64.21875 ,64. ,65. ,66. ,65.5 , - 66.5, 67.5, 67. ,68. ,69. ,68.78125 ,69.78125 ,70.78125 ,70., 71. , - 72. ,70.28125 ,71.28125 ,72.28125 ,65.875 ,66.875, 67.875 ,67.09375 ,68.09375 ,69.09375, - 68.875 ,69.875 ,70.875, 70.375 ,71.375 ,72.375 ,71.875 ,72.875 ,73.875 ,73.65625, - 74.65625 ,75.65625 ,74.875 ,75.875, 76.875 ,75.15625 ,76.15625, - 77.15625 ,73. ,74. ,75., 74.21875 ,75.21875 ,76.21875, - 76. ,77. ,78. ,77.5 ,78.5 ,79.5 ,79., - 80. ,81. ,80.78125 ,81.78125, 82.78125 ,82. ,83., - 84. ,82.28125 ,83.28125 ,84.28125, 79. ,80. ,81., - 80.21875 ,81.21875 ,82.21875 ,82., 83. ,84. ,83.5, - 84.5 ,85.5 ,85. ,86., 87. ,86.78125 ,87.78125, - 88.78125 ,88. ,89. ,90., 88.28125 ,89.28125 ,90.28125, - 85. ,86. ,87. ,86.21875, 87.21875 ,88.21875 ,88., - 89. ,90. ,89.5 ,90.5, 91.5 ,91. ,92., - 93. ,92.78125 ,93.78125 ,94.78125, 94. ,95. ,96., - 94.28125 ,95.28125 ,96.28125 ,91., 92. ,93. ,92.21875, - 93.21875 ,94.21875 ,94. ,95., 96. ,95.5 ,96.5, - 97.5 ,97. ,98. ,99., 98.78125 ,99.78125 ,100.78125, - 100. ,101. ,102. ,100.28125, 101.28125 ,102.28125, 97., - 98. ,99. ,98.21875 ,99.21875, 100.21875 ,100., 101., - 102. ,101.5 ,102.5 ,103.5, 103. ,104., 105., - 104.78125 ,105.78125 ,106.78125 ,106., 107. ,108., 106.28125, - 107.28125 ,108.28125 ,104.125 ,105.125, 106.125 ,105.34375, 106.34375, - 107.34375 ,107.125 ,108.125 ,109.125, 108.625 ,109.625, 110.625, - 110.125 ,111.125 ,112.125 ,111.90625, 112.90625 ,113.90625, 113.125, - 114.125 ,115.125 ,113.40625 ,114.40625, 115.40625 ,109., 110., - 111. ,110.21875 ,111.21875 ,112.21875, 112., 113., 114., - 113.5 ,114.5 ,115.5 ,115., 116., 117., 116.78125, - 117.78125 ,118.78125 ,118. ,119., 120., 118.28125, 119.28125, - 120.28125 ,110.125 ,111.125 ,112.125, 111.34375, 112.34375, 113.34375, - 113.125 ,114.125 ,115.125 ,114.625, 115.625, 116.625, 116.125, - 117.125 ,118.125 ,117.90625, 118.90625, 119.90625, 119.125, 120.125, - 121.125 ,119.40625 ,120.40625, 121.40625}); //input = 1.f; + NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, + 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f, + 34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f, + 38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, + 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, + 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f, + 45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f, + 50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f, + 53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f, + 52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f, + 56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f, + 50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f, + 54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f, + 59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f, + 62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f, + 67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f, + 70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f, + 68.875000f, 69.875000f, 70.875000f, 70.375000f, 71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f, + 73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f, + 73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f, + 77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f, + 82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f, + 80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f, + 85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f, + 88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f, + 88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f, + 92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f, + 91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f, + 95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f, + 100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f, + 99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f, + 102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f, + 106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f, + 106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f, + 109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f, + 113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 110.000000f, + 111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f, + 114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f, + 118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f, + 112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f, + 115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f, + 119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f + }); //input = 1.f; input.linspace(1); auto size = NDArrayFactory::create({10, 8}); nd4j::ops::resize_bicubic op; @@ -733,48 +732,23 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { - 1. ,2. ,3. ,4., - 2.625 ,3.625 ,4.625 ,5.625, - 5. ,6. ,7. ,8., - 7.375 ,8.375 ,9.375, 10.375, - 9. ,10. ,11. ,12., - 9.375 ,10.375 ,11.375 ,12.375, - - 5.875 ,6.875 ,7.875 , 8.875 , - 7.5 ,8.5 ,9.5 , 10.5 , - 9.875 ,10.875 ,11.875, 12.875, - 12.25 ,13.25 ,14.25 , 15.25 , - 13.875 ,14.875 ,15.875, 16.875, - 14.25 ,15.25 ,16.25 , 17.25 , - - 13. ,14. ,15. ,16., - 14.625 ,15.625 ,16.625 ,17.625, - 17. ,18. ,19. ,20., - 19.375 ,20.375 ,21.375 ,22.375, - 21. ,22. ,23. ,24., - 21.375 ,22.375 ,23.375 ,24.375, - - 20.125 ,21.125 ,22.125 ,23.125, - 21.75 ,22.75 ,23.75 ,24.75, - 24.125 ,25.125 ,26.125 ,27.125, - 26.5 ,27.5 ,28.5 ,29.5, - 28.125 ,29.125 ,30.125 ,31.125, - 28.5 ,29.5 ,30.5 ,31.5, - - 25. , 26. , 27. , 28., - 26.625 ,27.625 ,28.625 ,29.625, - 29. ,30. ,31. ,32., - 31.375 ,32.375 ,33.375 ,34.375, - 33. ,34. ,35. ,36., - 33.375 ,34.375 ,35.375 ,36.375, - - 26.125, 27.125, 28.125, 29.125, - 27.75 ,28.75 ,29.75 ,30.75, - 30.125 ,31.125 ,32.125 ,33.125, - 32.5 ,33.5 ,34.5 ,35.5, - 34.125 ,35.125 ,36.125 ,37.125, - 34.5 ,35.5 ,36.5 ,37.5 + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f, + 6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f, + 8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, + 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f, + 15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, + 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f, + 22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, + 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f, + 25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, + 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f, + 28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, + 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f, + 34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, + 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f, + 35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f }); input.linspace(1); auto size = NDArrayFactory::create({6, 6}); @@ -795,60 +769,24 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { - 1. , 2. , 3. , - 2.21875 ,3.21875 ,4.21875, - 4. ,5. ,6. , - 5.5 ,6.5 ,7.5 , - 7. ,8. ,9. , - 8.78125 ,9.78125, 10.78125, - 10. ,11., 12. , - 10.28125 ,11.28125, 12.28125, - - 5.875 , 6.875 , 7.875 , - 7.09375 , 8.09375 , 9.09375, - 8.875 , 9.875 ,10.875 , - 10.375 ,11.375 ,12.375 , - 11.875 ,12.875 ,13.875 , - 13.65625 ,14.65625 ,15.65625, - 14.875 ,15.875 ,16.875 , - 15.15625 ,16.15625 ,17.15625, - - 13., 14., 15., - 14.21875 ,15.21875 ,16.21875, - 16. ,17. ,18. , - 17.5 ,18.5 ,19.5 , - 19. ,20. ,21. , - 20.78125 ,21.78125 ,22.78125, - 22. ,23. ,24. , - 22.28125 ,23.28125 ,24.28125, - - 20.125 , 21.125 , 22.125, - 21.34375 ,22.34375 ,23.34375, - 23.125 ,24.125 ,25.125 , - 24.625 ,25.625 ,26.625 , - 26.125 ,27.125 ,28.125 , - 27.90625 ,28.90625 ,29.90625, - 29.125 ,30.125 ,31.125 , - 29.40625 ,30.40625 ,31.40625, - - 25. ,26. ,27. , - 26.21875 ,27.21875 ,28.21875, - 28. ,29. ,30. , - 29.5 ,30.5 ,31.5 , - 31. ,32. ,33. , - 32.78125 ,33.78125 ,34.78125, - 34. ,35. ,36. , - 34.28125 ,35.28125 ,36.28125, - - 26.125 ,27.125 , 28.125 , - 27.34375 ,28.34375 ,29.34375, - 29.125 ,30.125 ,31.125 , - 30.625 ,31.625 ,32.625 , - 32.125 ,33.125 ,34.125 , - 33.90625 ,34.90625 ,35.90625, - 35.125 ,36.125 ,37.125 , - 35.40625 ,36.40625 ,37.40625 }); + NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f, + 24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, + 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, + 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f, + 33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f + }); input.linspace(1); auto size = NDArrayFactory::create({6, 8}); nd4j::ops::resize_bicubic op; @@ -868,32 +806,30 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); - NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { - 1. ,2. , 3. , 2.21875 , 3.21875 , 4.21875 , 4. , 5. , - 6. ,5.5 , 6.5 , 7.5 , 7. , 8. , 9. , 8.78125 , - 9.78125 ,10.78125 ,10. ,11. ,12. ,10.28125 ,11.28125 ,12.28125 , - 5.875 ,6.875 , 7.875 , 7.09375 , 8.09375 , 9.09375 , 8.875 , 9.875 , - 10.875 ,10.375 , 11.375 , 12.375 , 11.875 , 12.875 , 13.875 , 13.65625, - 14.65625 ,15.65625, 14.875 , 15.875 , 16.875 , 15.15625, 16.15625, 17.15625, - 13. ,14. , 15. , 14.21875, 15.21875, 16.21875, 16. , 17. , - 18. ,17.5 , 18.5 , 19.5 , 19. , 20. , 21. , 20.78125, - 21.78125 ,22.78125, 22. , 23. , 24. , 22.28125, 23.28125, 24.28125, - 19. ,20. , 21. , 20.21875, 21.21875, 22.21875, 22. , 23. , - 24. ,23.5 , 24.5 , 25.5 , 25. , 26. , 27. , 26.78125, - 27.78125 ,28.78125, 28. , 29. , 30. , 28.28125, 29.28125, 30.28125, - 25. ,26. , 27. , 26.21875, 27.21875, 28.21875, 28. , 29. , - 30. ,29.5 , 30.5 , 31.5 , 31. , 32. , 33. , 32.78125, - 33.78125 ,34.78125, 34. , 35. , 36. , 34.28125, 35.28125, 36.28125, - 32.125 ,33.125 , 34.125 , 33.34375, 34.34375, 35.34375, 35.125 , 36.125 , - 37.125 ,36.625 , 37.625 , 38.625 , 38.125 , 39.125 , 40.125 , 39.90625, - 40.90625 ,41.90625, 41.125 , 42.125 , 43.125 , 41.40625, 42.40625, 43.40625, - 37. ,38. , 39. , 38.21875, 39.21875, 40.21875, 40. , 41. , - 42. ,41.5 , 42.5 , 43.5 , 43. , 44. , 45. , 44.78125, - 45.78125 ,46.78125, 46. , 47. , 48. , 46.28125, 47.28125, 48.28125, - 38.125 ,39.125 , 40.125 , 39.34375, 40.34375, 41.34375, 41.125 , 42.125 , - 43.125 ,42.625 , 43.625 , 44.625 , 44.125 , 45.125 , 46.125 , 45.90625, - 46.90625 ,47.90625, 47.125 , 48.125 , 49.125 , 47.40625, 48.40625, 49.40625, - }); + NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, + 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f, + 35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, + 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, + 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, + 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 40.125000f, + 39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, + 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f, + 47.406250f, 48.406250f, 49.406250f, + }); input.linspace(1); auto size = NDArrayFactory::create({8, 8}); nd4j::ops::resize_bicubic op; @@ -912,167 +848,118 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { - NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { - 1, 2.1, 3.15, 4.2, 5.15, 6.1, 7, - 8, 9.1, 10., 11, 12.9, 13.1, 14, - 15, 16., 17., 18, 19, 20., 21, - 22, 23., 24., 25, 26, 27, 28, - 30, 31, 32, 33, 34., 35, 36, - 37, 38, 39, 40, 41., 42, 43, - 44, 45, 46, 47, 48., 49, 50 + NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { + 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, + 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, + 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, + 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f }); - NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { - 1. ,1.1976162 ,1.4174359 ,1.6775769 ,1.9961575 ,2.3283265 , - 2.550918 ,2.7360606 ,2.9655411 ,3.2929654 ,3.5441515 ,3.7380352 , - 3.948995 ,4.248106 ,4.5073795 ,4.6843743 ,4.8572845 ,5.104302 , - 5.3869915 ,5.581401 ,5.7539616 ,5.974285 ,6.272836 ,6.5204263 , - 6.718899 ,6.8871036 ,7.039068 ,7.099216 ,7.0784245 ,7.0281887 , - 2.247592 ,2.446947 ,2.6694887 ,2.9312382 ,3.248216 ,3.5745337 , - 3.78931 ,3.9656973 ,4.186417 ,4.5046535 ,4.740569 ,4.9217057 , - 5.133866 ,5.459533 ,5.7744613 ,6.0197873 ,6.254011 ,6.535633 , - 6.8097296 ,6.9607787 ,7.0749416 ,7.241601 ,7.5094895 ,7.7499495 , - 7.954571 ,8.131972 ,8.286526 ,8.346463 ,8.325745 ,8.275683 , - 3.6286845 ,3.830573 ,4.0569587 ,4.3211575 ,4.6364856 ,4.9556503 , - 5.160583 ,5.3258467 ,5.535462 ,5.84216 ,6.058749 ,6.223753 , - 6.437597 ,6.797369 ,7.1836042 ,7.5164022 ,7.8290343 ,8.154773 , - 8.417635 ,8.512958 ,8.5521 ,8.649708 ,8.87788 ,9.108794 , - 9.320926 ,9.509781 ,9.667375 ,9.72694 ,9.706349 ,9.656599 , - 5.276778 ,5.480438 ,5.709702 ,5.9754477 ,6.288551 ,6.6005697 , - 6.796207 ,6.9511423 ,7.1503997 ,7.4461427 ,7.644651 ,7.794562 , - 8.009684 ,8.400473 ,8.851847 ,9.26469 ,9.649218, 10.015648 , - 10.268647 ,10.313368 ,10.2843275 ,10.319379 ,10.512033 ,10.734956 , - 10.954604 ,11.154507 ,11.315369 ,11.374779 ,11.354242 ,11.304622 , - 7.325373 ,7.5284843 ,7.757575 ,8.022221 ,8.331997 ,8.638187 , - 8.827649 ,8.976217 ,9.168955 ,9.45726 ,9.6442375 ,9.784517 , - 9.999621, 10.407702 ,10.896234, 11.355122, 11.781423, 12.172186 , - 12.420712 ,12.4374485 ,12.370511 ,12.371386 ,12.545973 ,12.766424 , - 12.992249 ,13.20012 ,13.364252 ,13.424109 ,13.40342 ,13.353425 , - 9.493208 ,9.692467 ,9.9169445, 10.176801, 10.482199, 10.78547 , - 10.974367 ,11.123442 ,11.31637 ,11.603645 ,11.790616 ,11.930889 , - 12.144082 ,12.546447 ,13.024898 ,13.4723 ,13.889232 ,14.276275 , - 14.528972 ,14.555555 ,14.50145 ,14.515459 ,14.700572 ,14.927055 , - 15.156046 ,15.366046 ,15.532901 ,15.594008 ,15.5728855 ,15.521847 , - 10.970133 ,11.163599 ,11.380694 ,11.633735 ,11.935032 ,12.238887 , - 12.43254 ,12.588294 ,12.787534 ,13.079956 ,13.27752 ,13.426631 , - 13.636713 ,14.013844 ,14.441672 ,14.827978 ,15.191209 ,15.549808 , - 15.81343 ,15.881828 ,15.883522 ,15.950411 ,16.16933 ,16.40794 , - 16.636436 ,16.842583 ,17.010887 ,17.07363 ,17.05194 ,16.999537 , - 12.219155 ,12.406129 ,12.614796 ,12.860335 ,13.157928 ,13.464224 , - 13.665207 ,13.830567 ,14.039036 ,14.339629 ,14.552863 ,14.715049 , - 14.921564 ,15.264454 ,15.622843 ,15.924977 ,16.213829 ,16.532364 , - 16.8099 ,16.934835 ,17.012146 ,17.150164 ,17.413412 ,17.666712 , - 17.892765 ,18.09207 ,18.261044 ,18.325531 ,18.303238 ,18.249378 , - 13.7663965 ,13.947391 ,14.148263 ,14.386917 ,14.681246 ,14.990087 , - 15.198166 ,15.372728 ,15.590062 ,15.898583 ,16.126892 ,16.301655 , - 16.50487 ,16.815214 ,17.107498 ,17.329458 ,17.547403 ,17.827654 , - 18.118288 ,18.296928 ,18.4461 ,18.651634 ,18.956806 ,19.22382 , - 19.447308 ,19.639887 ,19.809319 ,19.875397 ,19.852556 ,19.797365 , - 15.9419365 ,16.118704 ,16.314133 ,16.547867 ,16.839561 ,17.14954 , - 17.361883 ,17.542162 ,17.764957 ,18.078188 ,18.315733 ,18.498205 , - 18.699116 ,18.988684 ,19.238989 ,19.410137 ,19.583265 ,19.839512 , - 20.13878 ,20.35177 ,20.546844 ,20.795671 ,21.128067 ,21.404358 , - 21.626736 ,21.8155 ,21.98561 ,22.052843 ,22.029604 ,21.973448 , - 17.53522 ,17.71077 ,17.904636 ,18.13695 ,18.42784 ,18.738056 , - 18.951529 ,19.133352 ,19.357613 ,19.672083 ,19.912102 ,20.096638 , - 20.296894 ,20.580765 ,20.819603 ,20.976887 ,21.137802 ,21.387535 , - 21.689209 ,21.911621 ,22.119276 ,22.37999 ,22.71991 ,22.998823 , - 23.22097 ,23.40876 ,23.57911 ,23.646685 ,23.623325 ,23.566887 , - 18.746353 ,18.922657 ,19.117487 ,19.350685 ,19.64207 ,19.952137 , - 20.164913 ,20.345781 ,20.569134 ,20.88284 ,21.12133 ,21.30459 , - 21.505253 ,21.792645 ,22.038572 ,22.204426 ,22.37289 ,22.626648 , - 22.926834 ,23.143423 ,23.343302 ,23.596668 ,23.931936 ,24.209232 , - 24.431519 ,24.619913 ,24.79011 ,24.857473 ,24.83419 ,24.777927 , - 20.16656 ,20.344206 ,20.540766 ,20.775532 ,21.067804 ,21.377607 , - 21.589132 ,21.768297 ,21.99003 ,22.302366 ,22.538124 ,22.719105 , - 22.920494 ,23.214176 ,23.472767 ,23.653934 ,23.83589 ,24.096842 , - 24.394371 ,24.600555 ,24.786541 ,25.026773 ,25.353731 ,25.62813 , - 25.850672 ,26.04014 ,26.210072 ,26.277063 ,26.253906 ,26.197956 , - 22.363024 ,22.54125 ,22.738552 ,22.973991 ,23.266647 ,23.57634 , - 23.787327 ,23.96576 ,24.186796 ,24.498543 ,24.733124 ,24.913122 , - 25.114826 ,25.411213 ,25.675262 ,25.863028 ,26.050789 ,26.314838 , - 26.611223 ,26.812925 ,26.992926 ,27.227505 ,27.550882 ,27.824034 , - 28.046684 ,28.236614 ,28.406433 ,28.473265 ,28.450163 ,28.394344 , - 24.429443 ,24.60767 ,24.80497 ,25.04041 ,25.333065 ,25.642756 , - 25.853743 ,26.032173 ,26.25321 ,26.564959 ,26.79954 ,26.97954 , - 27.181242 ,27.47763 ,27.74168 ,27.929441 ,28.117207 ,28.381254 , - 28.677637 ,28.879343 ,29.059345 ,29.293922 ,29.617298 ,29.890451 , - 30.113104 ,30.303034 ,30.472853 ,30.539684 ,30.516582 ,30.460762 , - 26. ,26.178228 ,26.375526 ,26.61097 ,26.903624 ,27.213314 , - 27.424305 ,27.602734 ,27.823772 ,28.135519 ,28.3701 ,28.550098 , - 28.7518 ,29.04819 ,29.312237 ,29.5 ,29.687763 ,29.951813 , - 30.2482 ,30.449903 ,30.629902 ,30.864483 ,31.187859 ,31.461012 , - 31.683659 ,31.873592 ,32.043407 ,32.11024 ,32.087135 ,32.03132 , - 27.570559 ,27.748787 ,27.946087 ,28.181528 ,28.474184 ,28.783876 , - 28.994865 ,29.173294 ,29.39433 ,29.70608 ,29.940659 ,30.120655 , - 30.32236 ,30.618746 ,30.882797 ,31.070557 ,31.25832 ,31.522371 , - 31.818754 ,32.02046 ,32.20046 ,32.43504 ,32.758415 ,33.031567 , - 33.25422 ,33.44415 ,33.613964 ,33.680794 ,33.657696 ,33.60188 , - 29.636976 ,29.815207 ,30.0125 ,30.247944 ,30.5406 ,30.85029 , - 31.061283 ,31.239712 ,31.46075 ,31.7725 ,32.00708 ,32.187077 , - 32.38878 ,32.685165 ,32.949215 ,33.13698 ,33.32474 ,33.58879 , - 33.885178 ,34.086884 ,34.26688 ,34.501457 ,34.824837 ,35.09799 , - 35.320637 ,35.510574 ,35.68039 ,35.747215 ,35.724117 ,35.6683 , - 31.83344 ,32.011665 ,32.20897 ,32.444412 ,32.73707 ,33.046757 , - 33.257744 ,33.436176 ,33.657207 ,33.96896 ,34.203537 ,34.383537 , - 34.58524 ,34.88163 ,35.145676 ,35.33344 ,35.521206 ,35.785255 , - 36.081642 ,36.28334 ,36.46334 ,36.69792 ,37.021297 ,37.294453 , - 37.517097 ,37.707027 ,37.876846 ,37.94368 ,37.920578 ,37.864758 , - 33.253647 ,33.431873 ,33.62917 ,33.864613 ,34.15727 ,34.466957 , - 34.677948 ,34.856377 ,35.077415 ,35.38916 ,35.623745 ,35.803745 , - 36.005447 ,36.301834 ,36.565884 ,36.753647 ,36.941406 ,37.205456 , - 37.50184 ,37.703545 ,37.883545 ,38.118122 ,38.4415 ,38.714653 , - 38.9373 ,39.127235 ,39.297054 ,39.363884 ,39.340782 ,39.28496 , - 34.464783 ,34.64301 ,34.840305 ,35.075752 ,35.368404 ,35.6781 , - 35.889088 ,36.067516 ,36.28855 ,36.6003 ,36.834885 ,37.014877 , - 37.216583 ,37.51297 ,37.77702 ,37.964783 ,38.152546 ,38.416595 , - 38.71298 ,38.914684 ,39.094685 ,39.32926 ,39.652645 ,39.925793 , - 40.14844 ,40.338375 ,40.508194 ,40.575024 ,40.55192 ,40.496105 , - 36.058067 ,36.23629 ,36.43359 ,36.669033 ,36.961685 ,37.271378 , - 37.48237 ,37.6608 ,37.881836 ,38.19359 ,38.42817 ,38.608162 , - 38.809868 ,39.10625 ,39.3703 ,39.558064 ,39.74583 ,40.00988 , - 40.306267 ,40.50797 ,40.68797 ,40.92255 ,41.245926 ,41.519077 , - 41.741722 ,41.931652 ,42.101475 ,42.168304 ,42.145203 ,42.089386 , - 38.315002 ,38.493233 ,38.690533 ,38.925976 ,39.218628 ,39.52832 , - 39.739307 ,39.917736 ,40.138775 ,40.45052 ,40.685104 ,40.865097 , - 41.066803 ,41.36319 ,41.627243 ,41.815002 ,42.002766 ,42.26682 , - 42.5632 ,42.764908 ,42.944904 ,43.179485 ,43.50286 ,43.776016 , - 43.998665 ,44.188595 ,44.358418 ,44.425247 ,44.402145 ,44.34633 , - 40.22708 ,40.40531 ,40.602608 ,40.83805 ,41.130707 ,41.440395 , - 41.651382 ,41.82982 ,42.050854 ,42.3626 ,42.597183 ,42.77718 , - 42.97888 ,43.27527 ,43.53932 ,43.72708 ,43.914845 ,44.178894 , - 44.47528 ,44.676983 ,44.856983 ,45.09156 ,45.41494 ,45.68809 , - 45.91074 ,46.100674 ,46.270493 ,46.337322 ,46.31422 ,46.2584 , - 41.785618 ,41.963844 ,42.161144 ,42.396584 ,42.68924 ,42.998936 , - 43.209923 ,43.388355 ,43.609394 ,43.921143 ,44.15572 ,44.335716 , - 44.53742 ,44.833805 ,45.09786 ,45.285614 ,45.473377 ,45.737427 , - 46.033817 ,46.235523 ,46.415524 ,46.650105 ,46.973476 ,47.24663 , - 47.469276 ,47.65921 ,47.82903 ,47.895855 ,47.872753 ,47.81694 , - 43.11514 ,43.293365 ,43.490665 ,43.726105 ,44.018764 ,44.328457 , - 44.539444 ,44.717873 ,44.93891 ,45.25066 ,45.48524 ,45.665237 , - 45.86694 ,46.163326 ,46.427376 ,46.615143 ,46.802902 ,47.066956 , - 47.363342 ,47.56505 ,47.74505 ,47.979626 ,48.302998 ,48.576153 , - 48.798798 ,48.98873 ,49.158546 ,49.225376 ,49.202282 ,49.146458 , - 44.303867 ,44.482094 ,44.679394 ,44.914833 ,45.207493 ,45.51718 , - 45.72817 ,45.9066 ,46.12764 ,46.439384 ,46.673965 ,46.853966 , - 47.055668 ,47.352055 ,47.6161 ,47.803867 ,47.99163 ,48.25568 , - 48.552063 ,48.75377 ,48.933773 ,49.16835 ,49.491726 ,49.764877 , - 49.987526 ,50.17746 ,50.347275 ,50.4141 ,50.391006 ,50.335186 , - 44.771675 ,44.949905 ,45.1472 ,45.382645 ,45.6753 ,45.98499 , - 46.195976 ,46.374413 ,46.595448 ,46.907196 ,47.141773 ,47.321774 , - 47.523476 ,47.819862 ,48.08391 ,48.27168 ,48.459446 ,48.72349 , - 49.019882 ,49.22158 ,49.401585 ,49.63616 ,49.959538 ,50.232693 , - 50.455338 ,50.64527 ,50.81509 ,50.88192 ,50.858818 ,50.803 , - 44.609966 ,44.788193 ,44.985493 ,45.220936 ,45.51359 ,45.82328 , - 46.03427 ,46.2127 ,46.433743 ,46.74549 ,46.98007 ,47.160065 , - 47.36177 ,47.658157 ,47.922207 ,48.10997 ,48.297733 ,48.561783 , - 48.858166 ,49.059875 ,49.239872 ,49.47445 ,49.79783 ,50.07098 , - 50.293625 ,50.48356 ,50.653378 ,50.720203 ,50.6971 ,50.64128 , - 44.219246 ,44.397472 ,44.594772 ,44.83021 ,45.122868 ,45.43256 , - 45.643543 ,45.82198 ,46.04302 ,46.354763 ,46.589344 ,46.76934 , - 46.971046 ,47.267433 ,47.531483 ,47.719242 ,47.907005 ,48.17105 , - 48.467438 ,48.66914 ,48.849144 ,49.08372 ,49.4071 ,49.680256 , - 49.902905 ,50.092834 ,50.262653 ,50.329483 ,50.30638 ,50.25057}); + NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { + 1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f, + 3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, + 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f, + 7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, + 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f, + 6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f, + 4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 6.223753f, + 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f, + 8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f, + 7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f, + 11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, + 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f, + 10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f, + 14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f, + 13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f, + 17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, + 15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.092070f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f, + 14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f, + 18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f, + 18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f, + 22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, + 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, + 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f, + 19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f, + 23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, + 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f, + 22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f, + 26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, + 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, 25.114826f, 25.411213f, 25.675262f, + 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f, + 25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, + 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f, + 29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f, + 28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, + 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f, + 32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f, + 31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, + 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f, + 30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, + 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f, + 34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, + 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f, + 33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, + 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 37.517097f, 37.707027f, 37.876846f, + 37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, + 36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, + 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f, + 35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, + 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f, + 39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, + 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f, + 38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, + 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f, + 42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, + 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f, + 41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f, + 40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, + 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f, + 45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 43.388355f, 43.609394f, + 43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f, + 47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f, + 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f, + 44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f, + 49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f, + 46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, + 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f, + 50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, + 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f, + 48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, + 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f, + 44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 48.849144f, + 49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f + }); auto size = NDArrayFactory::create({30, 30}); nd4j::ops::resize_bicubic op; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 488adad0c..d87acc439 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -229,6 +229,640 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { ASSERT_TRUE(e.equalsTo(out)); delete result; } + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto e = NDArrayFactory::create('c', {1, 3, 4}, { + -3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16. + }); + x.linspace(1.); + nd4j::ops::adjust_contrast_v2 op; + auto result = op.execute({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); +// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} + +/* + * public void testAdjustContrast1() { + INDArray in = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f + }).reshape(8,8,3,1); + INDArray out = Nd4j.create(DataType.FLOAT, in.shape()); + INDArray[] res = Nd4j.exec(new AdjustContrast(in, 2.0, out)); + assertArrayEquals(out.shape(), in.shape()); + //assertEquals(expected, out); + } + * */ + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { + auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { + 1.0218375f, + 1.0666375f, + 0.9130375f, + + -0.07396251f, + 0.91843754f, + -0.17496246f, + + 0.47543746f, + 1.2492375f, + 0.55643755f, + + 1.3110375f, + -0.36456245f, + 1.0518374f, + + 0.7824375f, + 0.57523745f, + -0.21656245f, + + 0.0816375f, + -0.2261625f, + 0.40323752f, + + 1.4520376f, + 0.6868375f, + 0.81723756f, + + -0.17576247f, + 0.81423753f, + -0.08656245f, + + + -0.36249164f, + 0.45590833f, + 1.1925083f, + + 0.00650835f, + 1.4861084f, + 1.2079083f, + + 0.05270836f, + 0.37350836f, + 0.94130826f, + + 1.0715083f, + 0.6103083f, + 0.9825083f, + + 0.07370833f, + -0.4518917f, + -0.39889166f, + + -0.3354917f, + 1.2213084f, + 1.0345083f, + + -0.3132917f, + 0.78470826f, + 0.23390833f, + + 0.6943083f, + 0.68170834f, + -0.09989169f, + + + 0.8352709f, + 1.3798709f, + 0.15507084f, + + 0.26607084f, + -0.10792917f, + 1.2302709f, + + 0.6448709f, + -0.29992914f, + 1.3534708f, + + 0.86607087f, + 0.37607086f, + 0.04027084f, + + 0.40087086f, + 0.59507084f, + 0.9416709f, + + 0.53127086f, + -0.01712915f, + 1.4610709f, + + -0.17152917f, + -0.13992918f, + 0.6242708f, + + -0.42192918f, + 0.38387084f, + -0.15752912f, + + + 0.3311833f, + 0.00618333f, + 0.17538333f, + + 0.10418332f, + 0.8365834f, + 0.27098334f, + + 1.2421833f, + -0.1114167f, + 1.0153834f, + + 0.9523833f, + 0.8317833f, + 0.9633833f, + + 0.6501833f, + 0.04258335f, + 0.9999833f, + + -0.40181667f, + 0.11418331f, + 0.47938335f, + + 1.1057833f, + -0.29761666f, + 1.0779834f, + + 0.5243833f, + -0.32181668f, + 1.1833833f, + + + 0.73157084f, + 0.4317708f, + 0.7283708f, + + 1.2297708f, + 0.4307708f, + 0.85377085f, + + 0.05977082f, + -0.09282917f, + 0.33957082f, + + 1.0751709f, + 0.2119708f, + 0.51897085f, + + -0.25302917f, + 1.1723708f, + -0.12562919f, + + 1.1993709f, + 0.5257708f, + 0.40517086f, + + 0.53197086f, + 0.8441708f, + 0.02617085f, + + -0.0208292f, + 0.8711709f, + 0.04137081f, + + + 0.74936247f, + 0.6085625f, + 0.8997625f, + + -0.08743751f, + 0.18576252f, + -0.17563748f, + + 0.5991625f, + -0.0038375f, + 0.07576251f, + + 0.42536253f, + -0.22823751f, + 0.36296248f, + + 0.81456256f, + -0.16183749f, + 0.5161625f, + + -0.21183747f, + 0.7429625f, + 0.6217625f, + + 0.17656249f, + 0.02616251f, + -0.17923748f, + + 1.4659625f, + 0.40016252f, + 0.28356248f, + + + 0.4195791f, + 0.8745791f, + 0.36637908f, + + 0.50597906f, + -0.17942089f, + 0.16917908f, + + 1.0235791f, + 1.3699791f, + -0.11382091f, + + -0.0918209f, + 0.7757791f, + 0.09017909f, + + 1.3807791f, + -0.15202093f, + 1.3875791f, + + -0.1712209f, + 1.3989791f, + 0.43777913f, + + 0.7855791f, + 0.1423791f, + 1.4711791f, + + 0.6455791f, + 0.6211791f, + -0.48062086f, + + + 0.10189578f, + 0.5628958f, + 0.68909574f, + + 0.96649575f, + -0.09370419f, + 1.3466958f, + + 1.4584957f, + 1.3544958f, + -0.3829042f, + + 0.11269578f, + -0.47890422f, + 1.0436958f, + + 0.6128957f, + 0.27209583f, + 0.2714958f, + + 0.21889582f, + 0.08789578f, + 1.1296958f, + + 0.4596958f, + 0.39309582f, + 0.8344958f, + + 0.71149576f, + -0.4799042f, + 0.4880958f + }); + + nd4j::ops::adjust_contrast op; + auto result = op.execute({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); +// out->printBuffer("Adjusted Constrast6"); +// e.printBuffer("Adjusted Expected 6"); +// ASSERT_TRUE(e.equalsTo(out)); + delete result; +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { + auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { + 1.0218375 , + 1.0666375 , + 0.9130375 , + + -0.07396251, + 0.91843754, + -0.17496246, + + 0.47543746, + 1.2492375 , + 0.55643755, + + 1.3110375 , + -0.36456245, + 1.0518374 , + + 0.7824375 , + 0.57523745, + -0.21656245, + + 0.0816375 , + -0.2261625 , + 0.40323752, + + 1.4520376 , + 0.6868375 , + 0.81723756, + + -0.17576247, + 0.81423753, + -0.08656245, + + + -0.36249164, + 0.45590833, + 1.1925083 , + + 0.00650835, + 1.4861084 , + 1.2079083 , + + 0.05270836, + 0.37350836, + 0.94130826, + + 1.0715083 , + 0.6103083 , + 0.9825083 , + + 0.07370833, + -0.4518917 , + -0.39889166, + + -0.3354917 , + 1.2213084 , + 1.0345083 , + + -0.3132917 , + 0.78470826, + 0.23390833, + + 0.6943083 , + 0.68170834, + -0.09989169, + + + 0.8352709 , + 1.3798709 , + 0.15507084, + + 0.26607084, + -0.10792917, + 1.2302709 , + + 0.6448709 , + -0.29992914, + 1.3534708 , + + 0.86607087, + 0.37607086, + 0.04027084, + + 0.40087086, + 0.59507084, + 0.9416709 , + + 0.53127086, + -0.01712915, + 1.4610709 , + + -0.17152917, + -0.13992918, + 0.6242708 , + + -0.42192918, + 0.38387084, + -0.15752912, + + + 0.3311833 , + 0.00618333, + 0.17538333, + + 0.10418332, + 0.8365834 , + 0.27098334, + + 1.2421833 , + -0.1114167 , + 1.0153834 , + + 0.9523833 , + 0.8317833 , + 0.9633833 , + + 0.6501833 , + 0.04258335, + 0.9999833 , + + -0.40181667, + 0.11418331, + 0.47938335, + + 1.1057833 , + -0.29761666, + 1.0779834 , + + 0.5243833 , + -0.32181668, + 1.1833833 , + + + 0.73157084, + 0.4317708 , + 0.7283708 , + + 1.2297708 , + 0.4307708 , + 0.85377085, + + 0.05977082, + -0.09282917, + 0.33957082, + + 1.0751709 , + 0.2119708 , + 0.51897085, + + -0.25302917, + 1.1723708 , + -0.12562919, + + 1.1993709 , + 0.5257708 , + 0.40517086, + + 0.53197086, + 0.8441708 , + 0.02617085, + + -0.0208292 , + 0.8711709 , + 0.04137081, + + + 0.74936247, + 0.6085625 , + 0.8997625 , + + -0.08743751, + 0.18576252, + -0.17563748, + + 0.5991625 , + -0.0038375 , + 0.07576251, + + 0.42536253, + -0.22823751, + 0.36296248, + + 0.81456256, + -0.16183749, + 0.5161625 , + + -0.21183747, + 0.7429625 , + 0.6217625 , + + 0.17656249, + 0.02616251, + -0.17923748, + + 1.4659625 , + 0.40016252, + 0.28356248, + + + 0.4195791 , + 0.8745791 , + 0.36637908, + + 0.50597906, + -0.17942089, + 0.16917908, + + 1.0235791 , + 1.3699791 , + -0.11382091, + + -0.0918209 , + 0.7757791 , + 0.09017909, + + 1.3807791 , + -0.15202093, + 1.3875791 , + + -0.1712209 , + 1.3989791 , + 0.43777913, + + 0.7855791 , + 0.1423791 , + 1.4711791 , + + 0.6455791 , + 0.6211791 , + -0.48062086, + + + 0.10189578, + 0.5628958 , + 0.68909574, + + 0.96649575, + -0.09370419, + 1.3466958 , + + 1.4584957 , + 1.3544958 , + -0.3829042 , + + 0.11269578, + -0.47890422, + 1.0436958 , + + 0.6128957 , + 0.27209583, + 0.2714958 , + + 0.21889582, + 0.08789578, + 1.1296958 , + + 0.4596958 , + 0.39309582, + 0.8344958 , + + 0.71149576, + -0.4799042, + 0.4880958 + }); +// x.linspace(1.); + nd4j::ops::adjust_contrast_v2 op; + auto result = op.execute({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); +// out->printBuffer("Adjusted Constrast7"); +// e.printBuffer("Adjusted expected 7"); + auto diff = e - *out; +// diff.printBuffer("Adjusted subtract 7"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} + TEST_F(DeclarableOpsTests15, Test_BitCast_1) { auto x = NDArrayFactory::create('c', {2, 2, 2}); auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index bc9f03e2f..9f67ac49a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -71,9 +71,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a //Still failing 2019/09/11 "slogdet/.*", - // Failing 2019/11/14 - |https://github.com/eclipse/deeplearning4j/issues/8374 - "adjust_contrast/*", - "adjust_contrast/.*", //Failing 2019/09/11 - https://github.com/eclipse/deeplearning4j/issues/7965 "bincount/.*", // Failing 2019/11/14 https://github.com/eclipse/deeplearning4j/issues/8393 From 70e08c3a6c4748737793ef9259caa5224f480d7f Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 6 Dec 2019 21:19:33 +0300 Subject: [PATCH 09/18] reshape validation fix Signed-off-by: raver119 --- libnd4j/include/ops/declarable/generic/shape/reshape.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index ef5fe26cf..0bc80fa91 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -164,6 +164,7 @@ namespace nd4j { // we can launch op using Int arguments if (inputShape->size() == 1) { + REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined"); std::vector *arguments = block.getIArguments(); int e = 1; From b66154a9d486466d2b4faa51f0007d684663ce84 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 9 Dec 2019 14:16:11 +1100 Subject: [PATCH 10/18] Add ArraySavingListener for debugging (#114) Signed-off-by: AlexDBlack --- .../debugging/ArraySavingListener.java | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java new file mode 100644 index 000000000..9137fc831 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -0,0 +1,107 @@ +package org.nd4j.autodiff.listeners.debugging; + +import lombok.NonNull; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ArraySavingListener extends BaseListener { + + protected final File dir; + protected int count = 0; + + public ArraySavingListener(@NonNull File dir){ + + if(!dir.exists()){ + dir.mkdir(); + } + + if(dir.listFiles() != null && dir.listFiles().length > 0){ + throw new IllegalStateException("Directory is not empty: " + dir.getAbsolutePath()); + } + + this.dir = dir; + } + + @Override + public boolean isActive(Operation operation) { + return true; + } + + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + List outNames = op.getOutputsOfOp(); + for(int i=0; i m1 = toMap(files1); + Map m2 = toMap(files2); + + for(File f : files1){ + String name = f.getName(); + String varName = name.substring(name.indexOf('_') + 1, name.length()-4); //Strip "x_" and ".bin" + File f2 = m2.get(varName); + + INDArray arr1 = Nd4j.readBinary(f); + INDArray arr2 = Nd4j.readBinary(f2); + + //TODO String arrays won't work here! + boolean eq = arr1.equalsWithEps(arr2, eps); + if(eq){ + System.out.println("Equals: " + varName.replaceAll("__", "/")); + } else { + INDArray sub = arr1.sub(arr2); + INDArray diff = Nd4j.math.abs(sub); + double maxDiff = diff.maxNumber().doubleValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + sub.close(); + diff.close();; + } + arr1.close(); + arr2.close(); + } + } + + private static Map toMap(File[] files){ + Map ret = new HashMap<>(); + for(File f : files) { + String name = f.getName(); + String varName = name.substring(name.indexOf('_') + 1, name.length() - 4); //Strip "x_" and ".bin" + ret.put(varName, f); + } + return ret; + } +} From ae7933a42842b2bcb472e26b0dc256fd6bf2bc58 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 08:01:12 +0300 Subject: [PATCH 11/18] cpu truebroadcast fix Signed-off-by: raver119 --- .../helpers/cpu/TrueBroadcastHelper.cpp | 9 +++--- .../layers_tests/BroadcastableOpsTests.cpp | 30 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp index dbf080ac9..171d082a7 100644 --- a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp +++ b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp @@ -46,9 +46,9 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); @@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper::exec(const NDArray& xArr, const NDArray& yAr auto func = PRAGMA_THREADS_FOR { std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); @@ -167,9 +168,9 @@ void TrueBroadcastIntHelper::exec(const NDArray& xArr, const NDArray& yArr, N const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index ffa19412a..036117aa9 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) { ASSERT_TRUE(z.isSameShape(e)); ASSERT_TRUE(z.equalsTo(e)); } + +TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {4, 1, 128}); + auto z = NDArrayFactory::create('c', {4, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 128, 128}); + + x.assign(0.f); + y.assign(1.f); + z.assign(119.f); + e.assign(0.f); +/* + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::multiply op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + z.printIndexedBuffer(); +*/ + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + + //z.printIndexedBuffer(); + + ASSERT_EQ(e, z); +} From cea68c18f1dcaff2703d8a0735396d738ab200b3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 09:27:50 +0300 Subject: [PATCH 12/18] cuda broadcast fix Signed-off-by: raver119 --- .../helpers/cuda/TrueBroadcastHelper.cu | 40 +++++++------------ 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index 4b7904bca..f40690795 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -42,12 +42,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -57,9 +54,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI } __syncthreads(); - auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); - auto yCoords = xCoords + xRank; - auto zCoords = yCoords + yRank; + Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); + Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; + Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -94,7 +91,6 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI template template void TrueBroadcastHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - trueBroadcastCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } @@ -106,7 +102,7 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem + launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); @@ -128,12 +124,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -143,9 +136,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh } __syncthreads(); - auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); - auto yCoords = xCoords + xRank; - auto zCoords = yCoords + yRank; + Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); + Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; + Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -191,7 +184,7 @@ void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, co dim3 launchDims; launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem + launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); @@ -213,12 +206,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -228,9 +218,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha } __syncthreads(); - auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); - auto yCoords = xCoords + xRank; - auto zCoords = yCoords + yRank; + Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); + Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; + Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -276,7 +266,7 @@ void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const dim3 launchDims; launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem + launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec"); From 927d591421ac7917672af3bd7198de89f2982d4b Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 9 Dec 2019 09:25:39 +0200 Subject: [PATCH 13/18] ResizeBicubic added (#117) * ResizeBicubic added Some fixes. * Test fixed * Narrowed argument type changed to boolean * Clean up --- .../fake_quant_with_min_max_vars.cpp | 9 +- .../DifferentialFunctionFactory.java | 5 +- .../converters/ImportClassMapping.java | 1 + .../FakeQuantWithMinMaxVarsPerChannel.java | 46 +++++++++-- .../api/ops/impl/image/ResizeBicubic.java | 82 +++++++++++++++++++ .../custom/FakeQuantWithMinMaxArgs.java | 13 ++- .../custom/FakeQuantWithMinMaxVars.java | 14 +++- .../transforms/pairwise/arithmetic/DivOp.java | 4 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 16 +--- .../nd4j/linalg/custom/CustomOpsTests.java | 14 +--- 10 files changed, 159 insertions(+), 45 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBicubic.java diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp index 6d24827e5..ea16b2274 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/fake_quant_with_min_max_vars.cpp @@ -52,12 +52,11 @@ namespace nd4j { if (block.getIArguments() && block.getIArguments()->size()) numBits = INT_ARG(0); bool narrowed = false; - //INT_ARG(1); - if (block.getIArguments()->size() == 2) { - numBits = INT_ARG(0); - narrowed = INT_ARG(1); - REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of bits for quatization should be in between 2 and 16, but %i was given.", numBits); + if (block.getBArguments() && block.getBArguments()->size()) { + narrowed = B_ARG(0); } + REQUIRE_TRUE(numBits > 1 && numBits < 17, 0, "fake_quant_with_min_max_vars: Number of \ + bits for quantization should be in between 2 and 16, but %i was given.", numBits); helpers::fakeQuantWithMinMaxVars(x, min, max, numBits, narrowed, output); return ND4J_STATUS_OK; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 7f59d24e4..b0fc00bac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2612,8 +2612,9 @@ public class DifferentialFunctionFactory { return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); } - public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) { - return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); + public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max, + int num_bits, boolean narrow) { + return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable(); } public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 3c6b969b8..cb63dab61 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -87,6 +87,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class, org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class, + org.nd4j.linalg.api.ops.impl.image.ResizeBicubic.class, org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class, org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class, org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index c63cd3b56..d46529a84 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -21,30 +21,46 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; import java.util.Collections; import java.util.List; +import java.util.Map; public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { + protected boolean narrowRange; + protected int numBits; + public FakeQuantWithMinMaxVarsPerChannel() {} - public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) { + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) { Preconditions.checkArgument(min.isVector() && max.isVector() && min.length() == max.length(), "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); - inputArguments.add(x); - inputArguments.add(min); - inputArguments.add(max); + addInputArgument(x,min,max); + addIArgument(num_bits); + addBArgument(narrow); } - public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, - INDArray output) { - this(x,min,max); - outputArguments.add(output); + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, int num_bits) { + this(x, min, max, num_bits, false); } - public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, boolean narrow) { + this(x, min, max, 8, narrow); + } + + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max) { + this(x, min, max, 8, false); + } + + public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max, + int num_bits, boolean narrow) { super("", sameDiff, new SDVariable[]{x, min, max}); + addIArgument(num_bits); + addBArgument(narrow); } @Override @@ -57,6 +73,18 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { return "FakeQuantWithMinMaxVarsPerChannel"; } + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("narrow_range")){ + this.narrowRange = attributesForNode.get("narrow_range").getB(); + } + if(attributesForNode.containsKey("num_bits")) { + this.numBits = (int) attributesForNode.get("num_bits").getI(); + } + addIArgument(numBits); + addBArgument(narrowRange); + } + @Override public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 inputs, got %s", inputDataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBicubic.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBicubic.java new file mode 100644 index 000000000..18cb15617 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBicubic.java @@ -0,0 +1,82 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit, K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.image; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * ResizeBicubic op wrapper + * @author Alexander Stoyakin + */ +@NoArgsConstructor +public class ResizeBicubic extends DynamicCustomOp { + + protected boolean alignCorners = false; + protected boolean alignPixelCenters = false; + + public ResizeBicubic(@NonNull INDArray image, INDArray size, boolean alignCorners, boolean alignPixelCenters) { + addInputArgument(image, size); + addBArgument(alignCorners, alignPixelCenters); + } + + public ResizeBicubic(@NonNull SameDiff sameDiff, @NonNull SDVariable image, + SDVariable size, boolean alignCorners, boolean alignPixelCenters) { + super(sameDiff, new SDVariable[]{image, size}); + addBArgument(alignCorners, alignPixelCenters); + } + + @Override + public String opName() { + return "resize_bicubic"; + } + + @Override + public String tensorflowName() { + return "ResizeBicubic"; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); + + this.alignCorners = attributesForNode.get("align_corners").getB(); + this.alignPixelCenters = attributesForNode.get("half_pixel_centers").getB(); + addBArgument(alignCorners, alignPixelCenters); + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), + "Expected 1 or 2 input datatypes for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(Nd4j.defaultFloatingPointType()); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxArgs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxArgs.java index 8aeb26b48..5186fda30 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxArgs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxArgs.java @@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -37,11 +38,21 @@ public class FakeQuantWithMinMaxArgs extends DynamicCustomOp { addArgs(); } + public FakeQuantWithMinMaxArgs(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) { + Preconditions.checkArgument(min.isVector() && max.isVector() && + min.length() == max.length(), + "FakeQuantWithMinMaxArgs: min and max should be 1D tensors with the same length"); + addInputArgument(x,min,max); + addIArgument(num_bits); + addBArgument(narrow); + } + public FakeQuantWithMinMaxArgs(){ } protected void addArgs(){ iArguments.clear(); - addIArgument(numBits, narrowRange ? 1 : 0); + addIArgument(numBits); + addBArgument(narrowRange); addTArgument(min, max); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxVars.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxVars.java index bf09ae88c..efe7d71a1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxVars.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/FakeQuantWithMinMaxVars.java @@ -4,6 +4,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -33,11 +34,22 @@ public class FakeQuantWithMinMaxVars extends DynamicCustomOp { addArgs(); } + public FakeQuantWithMinMaxVars(INDArray x, INDArray min, INDArray max, int num_bits, boolean narrow) { + Preconditions.checkArgument(min.isVector() && max.isVector() && + min.length() == max.length(), + "FakeQuantWithMinMaxVars: min and max should be 1D tensors with the same length"); + addInputArgument(x,min,max); + addIArgument(num_bits); + addBArgument(narrow); + } + public FakeQuantWithMinMaxVars(){ } protected void addArgs(){ iArguments.clear(); - addIArgument(numBits, narrowRange ? 1 : 0); + bArguments.clear(); + addIArgument(numBits); + addBArgument(narrowRange); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index b76942e95..5273a2941 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp { } @Override - public String tensorflowName() { - return "Div"; + public String[] tensorflowNames() { + return new String[]{"Div","RealDiv"}; } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 9f67ac49a..9e3db5b1a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -111,29 +111,17 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 "zeros_like/rank2_float32_dtype_int.*", - // 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402 - "fake_quant/min_max_args_per_channel.*", - - // Suggesting TF 1.15 bug - "non_max_suppression_v2/float16.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450 "betainc.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8452 - "polygamma.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453 "roll/.*", // 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455 "matrix_band_part/.*", - // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8458 - "adjust_hue/.*", - - // 11.28.2019 failing https://github.com/eclipse/deeplearning4j/issues/8459 - "adjust_saturation/.*" + // 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507 + "resize_bicubic/int32.*" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 742ffae66..c3d4fe699 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -943,16 +943,9 @@ public class CustomOpsTests extends BaseNd4jTest { 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5); INDArray out = Nd4j.createUninitialized(x.shape()); - val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out); + val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max); Nd4j.exec(op); assertEquals(expected, out); - - /*TF: [[ 0.7801, 0.5966, 0.7260, 0.2320, 0.5084], - [ 0.1800, 0.5046, 0.8684, 0.3513, 0.5084], - [ 0.0877, 0.5966, 0.6600, 0.3513, 0.1604]] - SD: [[ 0.7770, 0.5969, 0.7232, 0.2310, 0.5098], - [ 0.1793, 0.5053, 0.8685, 0.3500, 0.5098], - [ 0.0874, 0.5969, 0.6574, 0.3500, 0.1597]]*/ } @Test @@ -1036,13 +1029,12 @@ public class CustomOpsTests extends BaseNd4jTest { INDArray min = Nd4j.createFromArray(new float[]{-63.65f}); INDArray max = Nd4j.createFromArray(new float[]{0.1f}); - INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1); INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}). reshape(1,2,3,1); - Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output)); + INDArray[] output = Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max)); - assertEquals(expected, output); + assertEquals(expected, output[0]); } @Test From ee5d25caa94d527ace43fc645437496ec8b80c83 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 11:17:16 +0300 Subject: [PATCH 14/18] cuda broadcast exec fix Signed-off-by: raver119 --- .../helpers/cuda/TrueBroadcastHelper.cu | 22 ++++++++++--------- .../layers_tests/BroadcastableOpsTests.cpp | 16 ++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index f40690795..fdbd001fd 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -66,17 +66,19 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) - if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) + if(ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; + } - if(iy >= 0) - if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) + if(iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; + } } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); @@ -100,8 +102,8 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); @@ -182,8 +184,8 @@ template void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); @@ -264,8 +266,8 @@ template void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec"); diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 036117aa9..238c2f15d 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -862,3 +862,19 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { ASSERT_EQ(e, z); } + +TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {768}); + auto z = NDArrayFactory::create('c', {4, 128, 768}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + + x.assign(1.f); + y.assign(2.f); + z.assign(119.f); + e.assign(2.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + + ASSERT_EQ(e, z); +} From 0175ace4c3936d65a81c46f9aa984e4b4198c8ba Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 9 Dec 2019 23:08:00 +1100 Subject: [PATCH 15/18] Small tweaks (#119) Signed-off-by: AlexDBlack --- .../debugging/ArraySavingListener.java | 27 +++++++++++++------ .../transforms/pairwise/arithmetic/DivOp.java | 4 +-- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java index 9137fc831..6b64c69d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ArraySavingListener.java @@ -7,7 +7,9 @@ import org.nd4j.autodiff.listeners.Operation; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; @@ -81,14 +83,23 @@ public class ArraySavingListener extends BaseListener { if(eq){ System.out.println("Equals: " + varName.replaceAll("__", "/")); } else { - INDArray sub = arr1.sub(arr2); - INDArray diff = Nd4j.math.abs(sub); - double maxDiff = diff.maxNumber().doubleValue(); - System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); - System.out.println("\t" + f.getAbsolutePath()); - System.out.println("\t" + f2.getAbsolutePath()); - sub.close(); - diff.close();; + if(arr1.dataType() == DataType.BOOL){ + INDArray xor = Nd4j.exec(new Xor(arr1, arr2)); + int count = xor.castTo(DataType.INT).sumNumber().intValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - boolean, # differences = " + count); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + xor.close(); + } else { + INDArray sub = arr1.sub(arr2); + INDArray diff = Nd4j.math.abs(sub); + double maxDiff = diff.maxNumber().doubleValue(); + System.out.println("FAILS: " + varName.replaceAll("__", "/") + " - max difference = " + maxDiff); + System.out.println("\t" + f.getAbsolutePath()); + System.out.println("\t" + f2.getAbsolutePath()); + sub.close(); + diff.close(); + } } arr1.close(); arr2.close(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index 5273a2941..b76942e95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -57,8 +57,8 @@ public class DivOp extends BaseDynamicTransformOp { } @Override - public String[] tensorflowNames() { - return new String[]{"Div","RealDiv"}; + public String tensorflowName() { + return "Div"; } From 425c747330c63152d84ffde25dda6aec98d2481f Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Mon, 9 Dec 2019 19:08:36 +0200 Subject: [PATCH 16/18] - permute threadsPerBlock and blocksPerGrid in signature of launching of cuda kernel for trueBroadcast op (#120) Signed-off-by: Yurii --- .../helpers/cuda/TrueBroadcastHelper.cu | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index fdbd001fd..12c3eb0c5 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -42,9 +42,12 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -54,9 +57,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -66,19 +69,17 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) + if(ix >= 0) + if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; - } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) + if(iy >= 0) + if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; - } } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); @@ -93,6 +94,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI template template void TrueBroadcastHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { + trueBroadcastCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } @@ -102,9 +104,9 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); @@ -126,9 +128,12 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -138,9 +143,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -184,9 +189,10 @@ template void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); @@ -208,9 +214,12 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -220,9 +229,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -266,9 +275,10 @@ template void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec"); From a5f5ac72b10c2a3bcc312a28cceed989c60a4f3e Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 20:08:59 +0300 Subject: [PATCH 17/18] reduce bool changes (#118) * reduce bool changes Signed-off-by: raver119 * reduce bool tweaks Signed-off-by: raver119 --- .../include/loops/cuda/reduce/reduce_bool.cu | 8 ++++- .../loops/cuda/reduce/reduce_float.chpp | 2 +- .../nd4j/linalg/api/ops/BaseReduceBoolOp.java | 3 ++ .../linalg/api/ops/impl/reduce/bool/All.java | 9 ++++++ .../linalg/api/ops/impl/reduce/bool/Any.java | 5 ++++ .../api/ops/impl/reduce/bool/IsInf.java | 4 +++ .../api/ops/impl/reduce/bool/IsNaN.java | 4 +++ .../ops/executioner/CudaExecutioner.java | 18 +++++++++-- .../nativecpu/ops/NativeOpExecutioner.java | 11 +++++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 30 +++++++++++++++++++ 10 files changed, 89 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu index a785094f1..52ca3decc 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_bool.cu +++ b/libnd4j/include/loops/cuda/reduce/reduce_bool.cu @@ -237,6 +237,8 @@ template template __host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, Nd4jLong *hXShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, Nd4jLong *hZShapeInfo, int *dimension, int dimensionLength, void *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { + nd4j_printf("Step A%i\n", -1); + if(shape::isEmpty(hXShapeInfo)) { if(shape::isEmpty(hZShapeInfo)) @@ -251,7 +253,8 @@ __host__ void ReduceBoolFunction::intermediateXD(dim3 launchDims, cudaStrea auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer(); // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hXShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); + functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShapeInfo, hZShapeInfo, z, zShapeInfo, hZShapeInfo, ptr, nullptr); + nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolDim empty(...) failed"); } else { simpleReduce<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); @@ -274,6 +277,9 @@ __host__ void ReduceBoolFunction::intermediateScalar(dim3 launchDims, cudaS auto res = cudaMemcpyAsync(z, &startingVal, sizeof(Z), cudaMemcpyHostToDevice, *stream); if (res != 0) throw nd4j::cuda_exception::build("ReduceBoolFunction::intermediateScalar: failed to copy resulting scalar", res); + + nd4j::DebugHelper::checkErrorCode(stream, "reduceBoolScalar empty(...) failed"); + } else { simpleScalar<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength, reductionBuffer, tadOnlyShapeInfo); diff --git a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp index ef366caf7..110cc0f68 100644 --- a/libnd4j/include/loops/cuda/reduce/reduce_float.chpp +++ b/libnd4j/include/loops/cuda/reduce/reduce_float.chpp @@ -249,7 +249,7 @@ __host__ void ReduceFloatFunction::intermediateXD(dim3 launchDims, cudaStre auto ptr = nd4j::LaunchContext::defaultContext()->getScalarPointer(); // scalar assign - functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hXShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); + functions::scalar::ScalarTransform::executeCudaShaped(launchDims, stream, 14, z, zShape, hZShapeInfo, z, zShape, hZShapeInfo, ptr, nullptr); } else { simpleReduce<<>>(x, xShape, extraParams, z, zShape, dimension, dimensionLength, reductionPointer, tadShapeInfo, tadOffsets); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java index 6e2801c67..dd2072758 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceBoolOp.java @@ -102,4 +102,7 @@ public abstract class BaseReduceBoolOp extends BaseReduceOp implements ReduceBoo "with 2 inputs, second input (axis) must be an integer datatype for %s, got %s", getClass(), dataTypes); return Collections.singletonList(DataType.BOOL); } + + + public abstract boolean emptyValue(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java index 60b835135..a465728d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java @@ -41,6 +41,10 @@ public class All extends BaseReduceBoolOp { super(x); } + public All(INDArray x, int... axis) { + super(x, axis); + } + @Override public int opNum() { return 1; @@ -65,4 +69,9 @@ public class All extends BaseReduceBoolOp { public String tensorflowName() { return "All"; } + + @Override + public boolean emptyValue() { + return true; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index 1cd31d19d..7daebd4cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -65,4 +65,9 @@ public class Any extends BaseReduceBoolOp { public String tensorflowName() { return "Any"; } + + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index c74acf734..cb93a832e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -71,4 +71,8 @@ public class IsInf extends BaseReduceBoolOp { return Collections.singletonList(f().zerosLike(arg())); } + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index 611219d3e..c8cd72f2c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -71,4 +71,8 @@ public class IsNaN extends BaseReduceBoolOp { return Collections.singletonList(f().zerosLike(arg())); } + @Override + public boolean emptyValue() { + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 904e1305e..16568fbf4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -935,6 +935,18 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } + // FIXME: this should be moved down to C++ on per-op basis + // reduce to scalar case, ReduceBool ops require special treatment + if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (op.z() == null) { + op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); + } else { + op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + } + + return context; + } + long st = profilingConfigurableHookIn(op); checkForCompression(op); @@ -994,9 +1006,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { } } - if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { - return null; - } + //if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape)) { + // return null; + //} val dataType = op.resultType(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index b6af2e5f2..d12efba59 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -265,7 +265,18 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } } + // FIXME: this should be moved down to C++ on per-op basis val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toIntVector()); + // reduce to scalar case, ReduceBool ops require special treatment + if (op instanceof BaseReduceBoolOp && op.x().isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) { + if (op.z() == null) { + op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue())); + } else { + op.z().assign(((BaseReduceBoolOp) op).emptyValue()); + } + + return op.z(); + } //validateDataType(Nd4j.dataType(), op); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index bac06b981..68551e53d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8134,6 +8134,36 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(Nd4j.createFromArray(1.0,2,3,4,5,6), hStack); } + + @Test + public void testReduceAll_1() { + val x = Nd4j.empty(DataType.FLOAT); + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x)); + + assertEquals(e, z); + } + + @Test + public void testReduceAll_2() { + val x = Nd4j.ones(DataType.FLOAT, 0); + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x)); + + assertEquals(e, z); + } + + @Test + public void testReduceAll_3() { + val x = Nd4j.create(DataType.FLOAT, 0); + assertEquals(1, x.rank()); + + val e = Nd4j.scalar(true); + val z = Nd4j.exec(new All(x, 0)); + + assertEquals(e, z); + } + @Override public char ordering() { return 'c'; From 4920f22fffaf00aafa4373252ef1d6fe705f79f5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 10 Dec 2019 12:11:05 +1100 Subject: [PATCH 18/18] Check for empty streams for NativeImageLoader + test (#121) Signed-off-by: AlexDBlack --- .../image/loader/NativeImageLoader.java | 8 ++++ .../image/loader/TestNativeImageLoader.java | 44 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 8f482846b..d2be87536 100644 --- a/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -24,6 +24,7 @@ import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.Image; import org.datavec.image.data.ImageWritable; import org.datavec.image.transform.ImageTransform; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.api.ndarray.INDArray; @@ -284,6 +285,9 @@ public class NativeImageLoader extends BaseImageLoader { private Mat streamToMat(InputStream is) throws IOException { if(buffer == null){ buffer = IOUtils.toByteArray(is); + if(buffer.length <= 0){ + throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); + } bufferMat = new Mat(buffer); return bufferMat; } else { @@ -292,6 +296,10 @@ public class NativeImageLoader extends BaseImageLoader { //(a) if numRead < buffer.length - got everything //(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data + if(numReadTotal <= 0){ + throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); + } + if(numReadTotal < buffer.length){ bufferMat.data().put(buffer, 0, numReadTotal); bufferMat.cols(numReadTotal); diff --git a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 544e46b77..5f634bab8 100644 --- a/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -24,7 +24,9 @@ import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -55,6 +57,9 @@ public class TestNativeImageLoader { static final long seed = 10; static final Random rng = new Random(seed); + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + @Test public void testConvertPix() throws Exception { PIX pix; @@ -554,4 +559,43 @@ public class TestNativeImageLoader { assertEquals(img1LargeBuffer, img1ExactBuffer); } + + @Test + public void testNativeImageLoaderEmptyStreams() throws Exception { + File dir = testDir.newFolder(); + File f = new File(dir, "myFile.jpg"); + f.createNewFile(); + + NativeImageLoader nil = new NativeImageLoader(32, 32, 3); + + try(InputStream is = new FileInputStream(f)){ + nil.asMatrix(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + nil.asImageMatrix(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + nil.asRowVector(is); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + + try(InputStream is = new FileInputStream(f)){ + INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); + nil.asMatrixView(is, arr); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg, msg.contains("decode image")); + } + } + }