Update master (#8511)

* cleaned up bert iterator tests (#110)

Signed-off-by: eraly <susan.eraly@gmail.com>

* Various pre-release fixes (#111)

* Various fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix default dtypes for MaxPoolWithArgmax

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small pre-release tweak (#112)

* Log UI address on launch as in previous Play-based UI

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Logging level tweak for UI

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* http not https

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* datavec python ensure host (#113)

* ensure host

* one more host ensure

* info->debug

* [WIP] reverse improvements (#115)

* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* reverse draft

Signed-off-by: raver119 <raver119@gmail.com>

* reverse kernel

Signed-off-by: raver119 <raver119@gmail.com>

* reverse kernel

Signed-off-by: raver119 <raver119@gmail.com>

* 2 micro fixes

Signed-off-by: raver119 <raver119@gmail.com>

* Shugeo resize fix5 (#102)

* Refactored resize images ops to use TF-like bool args as input.

* Refactored helpers for cpu implementation of resize_bilinear and resize_nearest_neighbor ops.

* Refactored cuda implementation for image.resize_bilinear and image.resize_nearest_neighbor ops helpers.

* Refactored nearest_neighbor resize op.

* Added a pair of tests for special case of resize_bilinear algorithm.

* Fixed issue with resize_bilinear op.

* Refactored cpu implementation for helpers with resize_nearest_neighbor op.

* Final fixed for resize ops to conform TF v.1.5

* Refactored cuda helpers for resize_neares_neighbor op.

* Fixed resize_bilinear to accept proper data.

* Fixed issue with non-float input for resize_bilinear op.

* Refactored cuda helper for resize_bilinear to proper process non-float inputs.

* Added tests for resize_bilinear to int inputs.

* Fixed ResizeBilinear wrapper

* Tests fixed

* Fixed float and bool constant to avoid overflow for some kind of compilers.

* Corrected float constants with float data type.

* Added f suffix for float constants.

* Corrected float constant to avoid overflow with initializing lists.

* Corrected float initializing list with float input.

* Corrected bool constant with initalizing list.

* Corrected float and bool values with initializing lists.

* Fixed wrong constant.

* Fixed issue with 1x1 input picture for resize.

* ResizeBilinear default values on import fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-06 11:10:44 +03:00 committed by GitHub
parent a6223d307b
commit 972fae60dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1420 additions and 883 deletions

View File

@ -21,6 +21,7 @@ import lombok.Getter;
import lombok.NoArgsConstructor;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
@ -60,6 +61,7 @@ public class NumpyArray {
setND4JArray();
if (copy){
nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address();
}
@ -85,6 +87,7 @@ public class NumpyArray {
setND4JArray();
if (copy){
nd4jArray = nd4jArray.dup();
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
this.address = nd4jArray.data().address();
}
}
@ -104,11 +107,12 @@ public class NumpyArray {
nd4jStrides[i] = strides[i] / elemSize;
}
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
}
public NumpyArray(INDArray nd4jArray){
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
DataBuffer buff = nd4jArray.data();
address = buff.pointer().address();
shape = nd4jArray.shape();

View File

@ -605,7 +605,7 @@ public class PythonExecutioner {
private static synchronized void _exec(String code) {
log.info(code);
log.debug(code);
log.info("CPython: PyRun_SimpleStringFlag()");
int result = PyRun_SimpleStringFlags(code, null);

View File

@ -17,11 +17,13 @@
package org.deeplearning4j.iterator;
import lombok.Getter;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
import org.junit.Test;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -42,8 +44,12 @@ import static org.junit.Assert.*;
public class TestBertIterator extends BaseDL4JTest {
private File pathToVocab = Resources.asFile("other/vocab.txt");
private static File pathToVocab = Resources.asFile("other/vocab.txt");
private static Charset c = StandardCharsets.UTF_8;
private static String shortSentence = "I saw a girl with a telescope.";
private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
private static String sentenceA = "Goodnight noises everywhere";
private static String sentenceB = "Goodnight moon";
public TestBertIterator() throws IOException {
}
@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest {
@Test(timeout = 20000L)
public void testBertSequenceClassification() throws Exception {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int minibatchSize = 2;
TestSentenceHelper testHelper = new TestSentenceHelper();
BertIterator b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
.minibatchSize(2)
.sentenceProvider(new TestSentenceProvider())
.minibatchSize(minibatchSize)
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();
@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest {
System.out.println(mds.getFeatures(0));
System.out.println(mds.getFeaturesMaskArray(0));
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens();
Map<String, Integer> m = t.getVocab();
for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
}
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens();
for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i);
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
for (int i = 0; i < minibatchSize; i++) {
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = testHelper.getTokenizedSentences().get(i);
System.out.println(tokens);
for (int j = 0; j < tokens.size(); j++) {
String token = tokens.get(j);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expEx1.putScalar(0, i, idx);
expM1.putScalar(0, i, 1);
expFTemp.putScalar(0, j, idx);
expMTemp.putScalar(0, j, 1);
}
if (i == 0) {
expF = expFTemp.dup();
expM = expMTemp.dup();
} else {
expF = Nd4j.vstack(expF, expFTemp);
expM = Nd4j.vstack(expM, expMTemp);
}
}
INDArray expF = Nd4j.vstack(expEx0, expEx1);
INDArray expM = Nd4j.vstack(expM0, expM1);
assertEquals(expF, mds.getFeatures(0));
assertEquals(expM, mds.getFeaturesMaskArray(0));
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
b.next(); //pop the third element
assertFalse(b.hasNext());
b.reset();
assertTrue(b.hasNext());
forInference.set(0, toTokenize2);
//Same thing, but with segment ID also
b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
.minibatchSize(2)
.sentenceProvider(new TestSentenceProvider())
.minibatchSize(minibatchSize)
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();
mds = b.next();
assertEquals(2, mds.getFeatures().length);
//assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
//Segment ID should be all 0s for single segment task
INDArray segmentId = expM.like();
assertEquals(segmentId, mds.getFeatures(1));
assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]);
}
@Test(timeout = 20000L)
public void testBertUnsupervised() throws Exception {
int minibatchSize = 2;
TestSentenceHelper testHelper = new TestSentenceHelper();
//Task 1: Unsupervised
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
BertIterator b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
.minibatchSize(2)
.sentenceProvider(new TestSentenceProvider())
.minibatchSize(minibatchSize)
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.UNSUPERVISED)
.masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
.unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
.maskToken("[MASK]")
.build();
System.out.println("Mask token index: " + t.getVocab().get("[MASK]"));
System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]"));
MultiDataSet mds = b.next();
System.out.println(mds.getFeatures(0));
@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest {
System.out.println(mds.getLabels(0));
System.out.println(mds.getLabelsMaskArray(0));
b.next(); //pop the third element
assertFalse(b.hasNext());
b.reset();
assertTrue(b.hasNext());
@ -164,39 +159,33 @@ public class TestBertIterator extends BaseDL4JTest {
@Test(timeout = 20000L)
public void testLengthHandling() throws Exception {
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens();
int minibatchSize = 2;
TestSentenceHelper testHelper = new TestSentenceHelper();
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
for (int i = 0; i < minibatchSize; i++) {
List<String> tokens = testHelper.getTokenizedSentences().get(i);
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
System.out.println(tokens);
Map<String, Integer> m = t.getVocab();
for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
}
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens();
System.out.println(tokens2);
for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i);
for (int j = 0; j < tokens.size(); j++) {
String token = tokens.get(j);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expEx1.putScalar(0, i, idx);
expM1.putScalar(0, i, 1);
expFTemp.putScalar(0, j, idx);
expMTemp.putScalar(0, j, 1);
}
if (i == 0) {
expF = expFTemp.dup();
expM = expMTemp.dup();
} else {
expF = Nd4j.vstack(expF, expFTemp);
expM = Nd4j.vstack(expM, expMTemp);
}
}
INDArray expF = Nd4j.vstack(expEx0, expEx1);
INDArray expM = Nd4j.vstack(expM0, expM1);
//--------------------------------------------------------------
@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest {
//Any length: as long as we need to fit longest sequence
BertIterator b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
.minibatchSize(2)
.sentenceProvider(new TestSentenceProvider())
.minibatchSize(minibatchSize)
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();
MultiDataSet mds = b.next();
@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest {
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
//Clip only: clip to maximum, but don't pad if less
b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
.minibatchSize(2)
.sentenceProvider(new TestSentenceProvider())
.minibatchSize(minibatchSize)
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();
mds = b.next();
expShape = new long[]{2, 14};
assertArrayEquals(expShape, mds.getFeatures(0).shape());
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest {
@Test(timeout = 20000L)
public void testMinibatchPadding() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
String toTokenize1 = "I saw a girl with a telescope.";
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
String toTokenize3 = "Goodnight noises everywhere";
List<String> forInference = new ArrayList<>();
forInference.add(toTokenize1);
forInference.add(toTokenize2);
forInference.add(toTokenize3);
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens = t.create(toTokenize1).getTokens();
Map<String, Integer> m = t.getVocab();
for (int i = 0; i < tokens.size(); i++) {
int idx = m.get(tokens.get(i));
expEx0.putScalar(0, i, idx);
expM0.putScalar(0, i, 1);
}
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens2 = t.create(toTokenize2).getTokens();
for (int i = 0; i < tokens2.size(); i++) {
String token = tokens2.get(i);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expEx1.putScalar(0, i, idx);
expM1.putScalar(0, i, 1);
}
INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
List<String> tokens3 = t.create(toTokenize3).getTokens();
for (int i = 0; i < tokens3.size(); i++) {
String token = tokens3.get(i);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expEx3.putScalar(0, i, idx);
expM3.putScalar(0, i, 1);
}
int minibatchSize = 3;
TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize);
INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
for (int i = 0; i < minibatchSize; i++) {
List<String> tokens = testHelper.getTokenizedSentences().get(i);
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
System.out.println(tokens);
for (int j = 0; j < tokens.size(); j++) {
String token = tokens.get(j);
if (!m.containsKey(token)) {
throw new IllegalStateException("Unknown token: \"" + token + "\"");
}
int idx = m.get(token);
expFTemp.putScalar(0, j, idx);
expMTemp.putScalar(0, j, 1);
}
if (i == 0) {
expF = expFTemp.dup();
expM = expMTemp.dup();
} else {
expF = Nd4j.vstack(expF.dup(), expFTemp);
expM = Nd4j.vstack(expM.dup(), expMTemp);
}
}
expF = Nd4j.vstack(expF, zeros);
expM = Nd4j.vstack(expM, zeros);
INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}});
INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
expLM.putScalar(0, 0, 1);
expLM.putScalar(1, 0, 1);
@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest {
//--------------------------------------------------------------
BertIterator b = BertIterator.builder()
.tokenizer(t)
.tokenizer(testHelper.getTokenizer())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
.minibatchSize(4)
.minibatchSize(minibatchSize + 1)
.padMinibatches(true)
.sentenceProvider(new TestSentenceProvider())
.sentenceProvider(testHelper.getSentenceProvider())
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.build();
@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest {
assertEquals(expL, mds.getLabels(0));
assertEquals(expLM, mds.getLabelsMaskArray(0));
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
}
/*
Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair
Checks different lengths for max length to check popping and padding
*/
@Test
public void testSentencePairsSingle() throws IOException {
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
boolean prependAppend;
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int shortL = t.create(shortSent).countTokens();
int longL = t.create(longSent).countTokens();
int numOfSentences;
TestSentenceHelper testHelper = new TestSentenceHelper();
int shortL = testHelper.getShortestL();
int longL = testHelper.getLongestL();
Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
MultiDataSet shortLongPair, shortSentence, longSentence;
MultiDataSet fromPair, leftSide, rightSide;
// check for pair max length exactly equal to sum of lengths - pop neither no padding
// should be the same as hstack with segment ids 1 for second sentence
prependAppend = true;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).addi(1);
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
numOfSentences = 1;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
fromPair = multiDataSetTriple.getFirst();
leftSide = multiDataSetTriple.getSecond();
rightSide = multiDataSetTriple.getThird();
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
//check for pair max length greater than sum of lengths - pop neither with padding
// features should be the same as hstack of shorter and longer padded with prepend/append
// segment id should 1 only in the longer for part of the length of the sentence
prependAppend = true;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
numOfSentences = 1;
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
fromPair = multiDataSetTriple.getFirst();
leftSide = multiDataSetTriple.getSecond();
rightSide = multiDataSetTriple.getThird();
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
//check for pair max length less than shorter sentence - pop both
//should be the same as hstack with segment ids 1 for second sentence if no prepend/append
int maxL = shortL - 2;
int maxL = 5;//checking odd
numOfSentences = 3;
prependAppend = false;
multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
shortLongPair = multiDataSetTriple.getFirst();
shortSentence = multiDataSetTriple.getSecond();
longSentence = multiDataSetTriple.getThird();
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
longSentence.getFeatures(1).addi(1);
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
fromPair = multiDataSetTriple.getFirst();
leftSide = multiDataSetTriple.getSecond();
rightSide = multiDataSetTriple.getThird();
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
rightSide.getFeatures(1).addi(1);
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
}
/*
Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs
Checks various max lengths
Has sentences of varying lengths
*/
@Test
public void testSentencePairsUnequalLengths() throws IOException {
//check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
//should be identical to hstack if there is no append, prepend
//batch size is 2
int mbS = 4;
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int shortL = t.create(shortSent).countTokens();
int longL = t.create(longSent).countTokens();
int sent1L = t.create(sent1).countTokens();
int sent2L = t.create(sent2).countTokens();
//won't check 2*shortL + 1 because this will always pop on the left
for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
int minibatchSize = 4;
int numOfSentencesinIter = 3;
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter);
int shortL = testPairHelper.getShortL();
int longL = testPairHelper.getLongL();
int sent1L = testPairHelper.getSentenceALen();
int sent2L = testPairHelper.getSentenceBLen();
System.out.println("Sentence Pairs, Left");
System.out.println(testPairHelper.getSentencesLeft());
System.out.println("Sentence Pairs, Right");
System.out.println(testPairHelper.getSentencesRight());
//anything outside this range more will need to check padding,truncation
for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) {
System.out.println("Running for max length = " + maxL);
MultiDataSet leftMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.tokenizer(testPairHelper.getTokenizer())
.minibatchSize(minibatchSize)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testPairHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceProvider())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider())
.padMinibatches(true)
.build().next();
MultiDataSet rightMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.tokenizer(testPairHelper.getTokenizer())
.minibatchSize(minibatchSize)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testPairHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceProvider(true))
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
.sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider())
.padMinibatches(true)
.build().next();
MultiDataSet pairMDS = BertIterator.builder()
.tokenizer(t)
.minibatchSize(mbS)
.tokenizer(testPairHelper.getTokenizer())
.minibatchSize(minibatchSize)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testPairHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
.sentencePairProvider(new TestSentencePairProvider())
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL)
.sentencePairProvider(testPairHelper.getPairSentenceProvider())
.padMinibatches(true)
.build().next();
//Left sentences here are {{shortSent},
// {longSent},
// {Sent1}}
//Right sentences here are {{longSent},
// {shortSent},
// {Sent2}}
//The sentence pairs here are {{shortSent,longSent},
// {longSent,shortSent}
// {Sent1, Sent2}}
//CHECK FEATURES
INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL);
//left side
INDArray leftFeatures = leftMDS.getFeatures(0);
INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L));
//right side
INDArray rightFeatures = rightMDS.getFeatures(0);
INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L));
//expected pair
combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat));
combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat));
combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat));
assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
assertEquals(combinedFeat, pairMDS.getFeatures(0));
//CHECK SEGMENT ID
INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL);
combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1);
assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
assertEquals(maxL, combinedFetSeg.shape()[1]);
assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
testPairHelper.getPairSentenceProvider().reset();
}
}
@Test
public void testSentencePairFeaturizer() throws IOException {
String shortSent = "I saw a girl with a telescope.";
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
List<Pair<String, String>> listSentencePair = new ArrayList<>();
listSentencePair.add(new Pair<>(shortSent, longSent));
listSentencePair.add(new Pair<>(longSent, shortSent));
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int minibatchSize = 2;
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
BertIterator b = BertIterator.builder()
.tokenizer(t)
.minibatchSize(2)
.tokenizer(testPairHelper.getTokenizer())
.minibatchSize(minibatchSize)
.padMinibatches(true)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.vocabMap(testPairHelper.getTokenizer().getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION)
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
.sentencePairProvider(new TestSentencePairProvider())
.sentencePairProvider(testPairHelper.getPairSentenceProvider())
.prependToken("[CLS]")
.appendToken("[SEP]")
.build();
@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest {
INDArray[] featuresArr = mds.getFeatures();
INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
assertEquals(p.getFirst().length, 2);
assertEquals(featuresArr[0], p.getFirst()[0]);
assertEquals(featuresArr[1], p.getFirst()[1]);
//assertEquals(p.getSecond().length, 2);
assertEquals(featuresMaskArr[0], p.getSecond()[0]);
//assertEquals(featuresMaskArr[1], p.getSecond()[1]);
}
/**
* Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
* Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator
* with given max lengths and whether to prepend/append
* Idea is the sentence pair dataset can be constructed from the single sentence datasets
* First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
* Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
* Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
*/
private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend, int numSentences) throws IOException {
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
int maxforPair = maxLengths.getFirst();
int maxPartOne = maxLengths.getSecond();
@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest {
BertIterator.Builder commonBuilder;
commonBuilder = BertIterator.builder()
.tokenizer(t)
.minibatchSize(1)
.minibatchSize(4)
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
.vocabMap(t.getVocab())
.task(BertIterator.Task.SEQ_CLASSIFICATION);
BertIterator shortLongPairFirstIter = commonBuilder
BertIterator pairIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
.sentencePairProvider(new TestSentencePairProvider())
.sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider())
.prependToken(prependAppend ? "[CLS]" : null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
BertIterator shortFirstIter = commonBuilder
BertIterator leftIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
.sentenceProvider(new TestSentenceProvider())
.sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider())
.prependToken(prependAppend ? "[CLS]" : null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
BertIterator longFirstIter = commonBuilder
BertIterator rightIter = commonBuilder
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
.sentenceProvider(new TestSentenceProvider(true))
.sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider())
.prependToken(null)
.appendToken(prependAppend ? "[SEP]" : null)
.build();
return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next());
}
private static class TestSentenceProvider implements LabeledSentenceProvider {
@Getter
private static class TestSentencePairsHelper {
private int pos = 0;
private boolean invert;
private List<String> sentencesLeft;
private List<String> sentencesRight;
private List<Pair<String, String>> sentencePairs;
private List<List<String>> tokenizedSentencesLeft;
private List<List<String>> tokenizedSentencesRight;
private List<String> labels;
private int shortL;
private int longL;
private int sentenceALen;
private int sentenceBLen;
private BertWordPieceTokenizerFactory tokenizer;
private CollectionLabeledPairSentenceProvider pairSentenceProvider;
private TestSentenceProvider() {
this.invert = false;
private TestSentencePairsHelper() throws IOException {
this(3);
}
private TestSentenceProvider(boolean invert) {
this.invert = invert;
private TestSentencePairsHelper(int minibatchSize) throws IOException {
sentencesLeft = new ArrayList<>();
sentencesRight = new ArrayList<>();
sentencePairs = new ArrayList<>();
labels = new ArrayList<>();
tokenizedSentencesLeft = new ArrayList<>();
tokenizedSentencesRight = new ArrayList<>();
tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
sentencesLeft.add(shortSentence);
sentencesRight.add(longSentence);
sentencePairs.add(new Pair<>(shortSentence, longSentence));
labels.add("positive");
if (minibatchSize > 1) {
sentencesLeft.add(longSentence);
sentencesRight.add(shortSentence);
sentencePairs.add(new Pair<>(longSentence, shortSentence));
labels.add("negative");
if (minibatchSize > 2) {
sentencesLeft.add(sentenceA);
sentencesRight.add(sentenceB);
sentencePairs.add(new Pair<>(sentenceA, sentenceB));
labels.add("positive");
}
}
for (int i = 0; i < minibatchSize; i++) {
List<String> tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens();
List<String> tokensR = tokenizer.create(sentencesRight.get(i)).getTokens();
if (i == 0) {
shortL = tokensL.size();
longL = tokensR.size();
}
if (i == 2) {
sentenceALen = tokensL.size();
sentenceBLen = tokensR.size();
}
tokenizedSentencesLeft.add(tokensL);
tokenizedSentencesRight.add(tokensR);
}
pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null);
}
}
@Override
public boolean hasNext() {
return pos < totalNumSentences();
@Getter
private static class TestSentenceHelper {
private List<String> sentences;
private List<List<String>> tokenizedSentences;
private List<String> labels;
private int shortestL = 0;
private int longestL = 0;
private BertWordPieceTokenizerFactory tokenizer;
private CollectionLabeledSentenceProvider sentenceProvider;
private TestSentenceHelper() throws IOException {
this(false, 2);
}
@Override
public Pair<String, String> nextSentence() {
Preconditions.checkState(hasNext());
if (pos == 0) {
pos++;
if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
private TestSentenceHelper(int minibatchSize) throws IOException {
this(false, minibatchSize);
}
private TestSentenceHelper(boolean alternateOrder) throws IOException {
this(false, 3);
}
private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException {
sentences = new ArrayList<>();
labels = new ArrayList<>();
tokenizedSentences = new ArrayList<>();
tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
if (!alternateOrder) {
sentences.add(shortSentence);
labels.add("positive");
if (minibatchSize > 1) {
sentences.add(longSentence);
labels.add("negative");
if (minibatchSize > 2) {
sentences.add(sentenceA);
labels.add("positive");
}
}
} else {
if (pos == 1) {
pos++;
if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
return new Pair<>("I saw a girl with a telescope.", "positive");
}
pos++;
if (!invert)
return new Pair<>("Goodnight noises everywhere", "positive");
return new Pair<>("Goodnight moon", "positive");
sentences.add(longSentence);
labels.add("negative");
if (minibatchSize > 1) {
sentences.add(shortSentence);
labels.add("positive");
if (minibatchSize > 2) {
sentences.add(sentenceB);
labels.add("positive");
}
}
@Override
public void reset() {
pos = 0;
}
@Override
public int totalNumSentences() {
return 3;
for (int i = 0; i < sentences.size(); i++) {
List<String> tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens();
if (i == 0)
shortestL = tokenizedSentence.size();
if (tokenizedSentence.size() > longestL)
longestL = tokenizedSentence.size();
if (tokenizedSentence.size() < shortestL)
shortestL = tokenizedSentence.size();
tokenizedSentences.add(tokenizedSentence);
}
@Override
public List<String> allLabels() {
return Arrays.asList("positive", "negative");
}
@Override
public int numLabelClasses() {
return 2;
}
}
private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
private int pos = 0;
@Override
public boolean hasNext() {
return pos < totalNumSentences();
}
@Override
public Triple<String, String, String> nextSentencePair() {
Preconditions.checkState(hasNext());
if (pos == 0) {
pos++;
return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
} else {
if (pos == 1) {
pos++;
return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
}
pos++;
return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
}
}
@Override
public void reset() {
pos = 0;
}
@Override
public int totalNumSentences() {
return 3;
}
@Override
public List<String> allLabels() {
return Arrays.asList("positive", "negative");
}
@Override
public int numLabelClasses() {
return 2;
sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null);
}
}

View File

@ -254,6 +254,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
uiEventRoutingThread.setDaemon(true);
uiEventRoutingThread.start();
String address = UIServer.getInstance().getAddress();
log.info("Deeplearning4j UI server started at: {}", address);
}
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
@ -317,7 +320,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
@Override
public String getAddress() {
return "https://localhost:" + server.actualPort();
return "http://localhost:" + server.actualPort();
}
@Override
@ -421,7 +424,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
private void runHelper() throws Exception {
log.info("VertxUIServer.StatsEventRouterRunnable started");
log.trace("VertxUIServer.StatsEventRouterRunnable started");
//Idea: collect all event stats, and route them to the appropriate modules
while (!shutdown.get()) {

View File

@ -1256,6 +1256,9 @@ namespace nd4j {
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
template<typename T>
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
template<typename T>
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w);
/**
* returns array element with given index
@ -1268,6 +1271,8 @@ namespace nd4j {
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;
template<typename T>
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const;
template<typename T>
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const;
/**
@ -1711,7 +1716,7 @@ namespace nd4j {
if (isEmpty())
return false;
return shape::isMatrix(this->_shapeInfo);
return 0 != shape::isMatrix(this->_shapeInfo);
}
//////////////////////////////////////////////////////////////////////////
@ -1751,7 +1756,7 @@ namespace nd4j {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isScalar() const {
return shape::isScalar(this->_shapeInfo);
return 0 != shape::isScalar(this->_shapeInfo);
}
//////////////////////////////////////////////////////////////////////////
@ -2082,7 +2087,7 @@ template <typename T>
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
if (DataTypeUtils::fromT<T>() != _dataType)
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
@ -2095,6 +2100,23 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
template <typename T>
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) {
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3))
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !");
if (DataTypeUtils::fromT<T>() != _dataType)
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
if(!isActualOnHostSide())
syncToHost();
Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords);
tickWriteHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
////////////////////////////////////////////////////////////////////////
template <typename T>
T NDArray::t(const Nd4jLong i) const {
@ -2133,7 +2155,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
if (DataTypeUtils::fromT<T>() != _dataType)
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
@ -2146,6 +2168,23 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
template <typename T>
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const {
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3))
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!");
if (DataTypeUtils::fromT<T>() != _dataType)
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
if(!isActualOnHostSide())
syncToHost();
Nd4jLong coords[4] = {i, j, k, w};
auto offset = shape::getOffset(getShapeInfo(), coords);
tickReadHost();
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
}
#ifndef __JAVACPP_HACK__
////////////////////////////////////////////////////////////////////////
std::shared_ptr<DataBuffer> NDArray::getDataBuffer() const {

View File

@ -2348,7 +2348,7 @@ NDArray NDArray::operator-(const NDArray& other) const {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &other});
return result;
@ -2394,7 +2394,7 @@ NDArray NDArray::operator/(const NDArray& other) const {
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
NDArray::prepareSpecialUse({&result}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&result}, {this, &other});
return result;

View File

@ -46,7 +46,7 @@ namespace nd4j {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT)
->setAllowedOutputTypes(1, DataType::INT64);
->setAllowedOutputTypes(1, {ALL_INTS});
}

View File

@ -35,6 +35,8 @@ namespace nd4j {
int width;
int height;
auto inRank = image->rankOf();
if (output->isEmpty()) return Status::OK();
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
@ -57,7 +59,7 @@ namespace nd4j {
if (block.numB()> 1)
halfPixelAlign = block.getBArguments()->at(1);
}
REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners");
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});

View File

@ -32,8 +32,10 @@ namespace nd4j {
NDArray* output = OUTPUT_VARIABLE(0);
int width;
int height;
bool center = false; // - default value
bool alignCorners = false; // - default value
auto inRank = image->rankOf();
if (output->isEmpty()) return Status::OK();
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
"tensor, but input has rank %i",
image->rankOf());
@ -46,21 +48,25 @@ namespace nd4j {
auto newImageSize = INPUT_VARIABLE(1);
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1);
if (block.numI() == 1) {
center = 0 != INT_ARG(0);
}
height = newImageSize->e<int>(0);
width = newImageSize->e<int>(1);
}
else {
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
width = INT_ARG(0);
height = INT_ARG(1);
if (block.numI() == 3)
center = 0 != INT_ARG(2);
REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided.");
height = INT_ARG(0);
width = INT_ARG(1);
}
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
if (block.numB() > 0)
alignCorners = B_ARG(0);
bool halfPixelCenter = false;
if (block.numB() > 1)
halfPixelCenter = B_ARG(1);
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false");
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}
DECLARE_SHAPE_FN(resize_bilinear) {
@ -83,7 +89,7 @@ namespace nd4j {
height = newImageSize->e<int>(1);
}
else {
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided.");
width = INT_ARG(0);
height = INT_ARG(1);
}
@ -101,7 +107,12 @@ namespace nd4j {
outputShape[2] = height;
outputShape[3] = in[3];
}
if (DataTypeUtils::isR(ArrayOptions::dataType(in))) {
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
}
else {
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
}
shapeList->push_back(CONSTANT(outputShape));
return shapeList;

View File

@ -31,35 +31,40 @@ namespace nd4j {
auto image = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
auto inRank = image->rankOf();
int width;
int height;
bool center = false; // - default value
bool alignCorners = false; // - default value
if (output->isEmpty()) return Status::OK();
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
width = newImageSize->e<int>(0);
height = newImageSize->e<int>(1);
if (block.numI() == 1) {
center = 0 != INT_ARG(0);
}
height = newImageSize->e<int>(0);
width = newImageSize->e<int>(1);
}
else {
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
width = INT_ARG(0);
height = INT_ARG(1);
if (block.numI() == 3)
center = 0 != INT_ARG(2);
REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
height = INT_ARG(0);
width = INT_ARG(1);
}
auto inRank = image->rankOf();
if (block.numB() > 0)
alignCorners = B_ARG(0);
bool halfPixelCenter = false;
if (block.numB() > 1)
halfPixelCenter = B_ARG(1);
REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false");
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}
DECLARE_SHAPE_FN(resize_nearest_neighbor) {

View File

@ -120,6 +120,27 @@ namespace helpers {
}
};
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
// floating point coordinates of the top,left pixel is 0.5,0.5.
struct HalfPixelScalerNN {
HalfPixelScalerNN(){};
inline float operator()(const int x, const float scale) const {
// Note that we subtract 0.5 from the return value, as the existing bilinear
// sampling code etc assumes pixels are in the old coordinate system.
return (static_cast<float>(x) + 0.5f) * scale;
}
};
// Older incorrect scaling method that causes all resizes to have a slight
// translation leading to inconsistent results. For example, a flip then a
// resize gives different results then a resize then a flip.
struct LegacyScaler {
LegacyScaler(){};
inline float operator()(const int x, const float scale) const {
return static_cast<float>(x) * scale;
}
};
struct WeightsAndIndices {
float _weight0;
float _weight1;
@ -133,7 +154,8 @@ namespace helpers {
int _advance; // advance value.
};
inline void computeInterpolationWeights(Nd4jLong outSize,
template <class Scaler>
inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize,
Nd4jLong inSize,
double scale,
BilinearInterpolationData *interpolationData) {
@ -143,10 +165,12 @@ namespace helpers {
auto func = PRAGMA_THREADS_FOR {
for (auto k = start; k < stop; k++) {
auto i = (outSize - k - 1);
double in = i * scale;
interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
double const in = scaler(i, scale);
double const in_f = nd4j::math::nd4j_floor<double, double>(in);
double const in_c = nd4j::math::nd4j_ceil<double, double>(in);
interpolationData[i]._bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
interpolationData[i]._topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
interpolationData[i]._interpolarValue = in - in_f;
}
};
samediff::Threads::parallel_for(func, 0, outSize);
@ -156,29 +180,29 @@ namespace helpers {
* Computes the bilinear interpolation from the appropriate 4 float points
* and the linear interpolation weights.
*/
static void
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels,
std::vector<BilinearInterpolationData> const& xs,
std::vector<BilinearInterpolationData> const& ys,
NDArray *output);
// static void
// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
// Nd4jLong outWidth, Nd4jLong channels,
// std::vector<BilinearInterpolationData> const& xs,
// std::vector<BilinearInterpolationData> const& ys,
// NDArray *output);
template<typename T>
template<typename T, typename Z>
static void
resizeImage_(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels,
std::vector<BilinearInterpolationData> const &xs,
std::vector<BilinearInterpolationData> const &ys,
NDArray *output) {
Z* pOutputBuf) {
Nd4jLong inRowSize = inWidth * channels;
Nd4jLong inBatchNumValues = inHeight * inRowSize;
Nd4jLong outRowSize = outWidth * channels;
T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
// T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
BilinearInterpolationData const* xsPtr = xs.data();
T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
// T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
auto computeBilinear = [](double topLeft, double topRight,
double bottomLeft, double bottomRight,
double xVal, double yVal) {
@ -214,8 +238,12 @@ namespace helpers {
samediff::Threads::parallel_tad(func, 0, batchSize);
}
template<typename T>
static int resizeBilinearFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
template<typename X, typename Z>
static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners,
bool const halfPixelCenter, NDArray *output) {
ImageResizerState st(alignCorners, halfPixelCenter);
st.validateAndCalculateOutputSize(images, width, height);
const Nd4jLong batchSize = images->sizeAt(0);
const Nd4jLong inHeight = images->sizeAt(1);
const Nd4jLong inWidth = images->sizeAt(2);
@ -230,28 +258,20 @@ namespace helpers {
return ND4J_STATUS_OK;
}
// Special case for TF compatibility
if((center && inHeight < 2) || (center && inWidth < 2)){
center = false;
}
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
// wrong input data
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
return ND4J_STATUS_BAD_ARGUMENTS;
}
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
std::vector<BilinearInterpolationData> ys(outHeight + 1);
std::vector<BilinearInterpolationData> xs(outWidth + 1);
// Compute the cached interpolation weights on the x and y dimensions.
computeInterpolationWeights(outHeight, inHeight, heightScale,
if (halfPixelCenter) {
computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale,
ys.data());
computeInterpolationWeights(outWidth, inWidth, widthScale, xs.data());
computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data());
}
else {
// Compute the cached interpolation weights on the x and y dimensions.
computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale,
ys.data());
computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data());
}
int xsSize = xs.size();
// Scale x interpolation weights to avoid a multiplication during iteration.
auto func = PRAGMA_THREADS_FOR {
@ -262,71 +282,84 @@ namespace helpers {
};
samediff::Threads::parallel_for(func, 0, xsSize);
resizeImage(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output);
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
return ND4J_STATUS_OK;
}
template<typename T>
int resizeNeighborFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
const Nd4jLong batchSize = images->sizeAt(0);
const Nd4jLong inHeight = images->sizeAt(1);
const Nd4jLong inWidth = images->sizeAt(2);
const Nd4jLong channels = images->sizeAt(3);
template <class Scaler, typename T>
void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
const Nd4jLong batchSize = st.batchSize;
const Nd4jLong inHeight = st.inHeight;
const Nd4jLong inWidth = st.inWidth;
const Nd4jLong channels = st.channels;
const Nd4jLong outHeight = output->sizeAt(1);
const Nd4jLong outWidth = output->sizeAt(2);
// Handle no-op resizes efficiently.
if (outHeight == inHeight && outWidth == inWidth) {
output->assign(images);
return Status::OK();
}
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
// wrong input data
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
return ND4J_STATUS_BAD_ARGUMENTS;
}
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
const Nd4jLong outHeight = st.outHeight;
const Nd4jLong outWidth = st.outWidth;
Scaler scaler;
auto func = PRAGMA_THREADS_FOR_2D {
for (auto b = start_x; b < stop_x; b += inc_x) {
for (auto y = start_y; y < stop_y; y += inc_y) {
Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(y * heightScale)), inHeight - 1);
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(y, st.heightScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(y, st.heightScale)));
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
if (halfPixelCenter) {
inY = nd4j::math::nd4j_max(0LL, inY);
}
for (auto x = 0; x < outWidth; ++x) {
Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(x * widthScale)),inWidth - 1);
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(x, st.widthScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(x, st.widthScale)));
Nd4jLong inX = nd4j::math::nd4j_min(posX,inWidth - 1);
if (halfPixelCenter) {
inX = nd4j::math::nd4j_max(0LL, inX);
}
// copy pixel over all channels
for (auto e = 0; e < channels; e++)
output->p(b, y, x, e, images->e<T>(b, inY, inX, e));
output->t<T>(b, y, x, e) = images->t<T>(b, inY, inX, e);
}
}
}
};
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
}
template<typename T>
int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
ImageResizerState st(alignCorners, halfPixelCenter);
st.validateAndCalculateOutputSize(images, width, height);
// Handle no-op resizes efficiently.
if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) {
output->assign(images);
return Status::OK();
}
if (halfPixelCenter)
resizeNeighbor<HalfPixelScalerNN, T>(st, images, alignCorners, true, output);
else
resizeNeighbor<LegacyScaler, T>(st, images, alignCorners, false, output);
return Status::OK();
}
void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels,
std::vector<BilinearInterpolationData> const &xs,
std::vector<BilinearInterpolationData> const &ys,
NDArray *output) {
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_,
(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
LIBND4J_TYPES);
// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
// Nd4jLong outWidth, Nd4jLong channels,
// std::vector<BilinearInterpolationData> const &xs,
// std::vector<BilinearInterpolationData> const &ys,
// NDArray *output) {
// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_,
// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
// NUMERIC_TYPES, FLOAT_TYPES);
// }
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_,
(images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
}
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_,
(images, width, height, center, output), LIBND4J_TYPES);
}
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
(images, width, height, center, output), LIBND4J_TYPES);
(images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
}
@ -586,16 +619,6 @@ namespace helpers {
}
}
// Older incorrect scaling method that causes all resizes to have a slight
// translation leading to inconsistent results. For example, a flip then a
// resize gives different results then a resize then a flip.
struct LegacyScaler {
LegacyScaler(){};
inline float operator()(const int x, const float scale) const {
return static_cast<float>(x) * scale;
}
};
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
const bool half_pixel_centers,
std::vector<WeightsAndIndices>* x_wais) {
@ -847,7 +870,7 @@ namespace helpers {
// simplified bicubic resize without antialiasing
//
template <typename T>
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
int res = st.validateAndCreateOutput(image, width, height);
@ -856,17 +879,17 @@ namespace helpers {
return res;
}
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
}
// ------------------------------------------------------------------------------------------------------------------ //
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
switch (method) {
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break;
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
case kResizeLanczos5:
case kResizeGaussian:

View File

@ -13,6 +13,20 @@
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
//
// @author sgazeos@gmail.com
@ -32,6 +46,38 @@ namespace helpers {
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
double interpolarValue;
};
// Older incorrect scaling method that causes all resizes to have a slight
// translation leading to inconsistent results. For example, a flip then a
// resize gives different results then a resize then a flip.
struct LegacyScaler {
_CUDA_HD LegacyScaler(){};
inline _CUDA_HD float operator()(const int x, const float scale) const {
return static_cast<float>(x) * scale;
}
};
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
// floating point coordinates of the top,left pixel is 0.5,0.5.
struct HalfPixelScaler {
_CUDA_HD HalfPixelScaler(){};
inline _CUDA_HD float operator()(const int x, const float scale) const {
// Note that we subtract 0.5 from the return value, as the existing bilinear
// sampling code etc assumes pixels are in the old coordinate system.
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
}
};
// Utility functions
// calculateResizeScale determines the float scaling factor.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
bool alignCorners) {
return (alignCorners && outSize > 1)
? (inSize - 1) / static_cast<float>(outSize - 1)
: inSize / static_cast<float>(outSize);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// computeInterpolationWeights kernel
// outSize - output length
@ -39,6 +85,7 @@ namespace helpers {
// scale - input scale
// interporationData - result
//
template <class Scaler>
static __global__ void computeInterpolationWeights(Nd4jLong outSize,
Nd4jLong inSize,
double scale,
@ -48,12 +95,18 @@ namespace helpers {
interpolationData[outSize].topIndex = 0;
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
Scaler scaler;
for (Nd4jLong i = outSize - tid; i >= 0; i -= step) {
double in = i * scale;
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
double in = scaler(i, scale);
// interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
// interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
double const in_f = nd4j::math::p_floor<double>(in);
double const in_c = nd4j::math::p_ceil<double>(in);
interpolationData[i].bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
interpolationData[i].topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
interpolationData[i].interpolarValue = in - in_f;
if (channels) {
math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels);
math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels);
@ -72,31 +125,33 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// resize image with bilinear interpolation algorithm kernel
//
template <typename T>
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, T* outputYptr, Nd4jLong* outputShape, Nd4jLong batchSize,
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
template <typename T, typename Z>
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr,
Nd4jLong* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels,
Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
auto pX = input + batch * inBatchNumValues;
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
double yVal = ys_[y].interpolarValue;
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
for (Nd4jLong x = 0; x < outWidth; x++) {
auto xsBottom = xs_[x].bottomIndex;
auto xsTop = xs_[x].topIndex;
auto xVal = xs_[x].interpolarValue;
// process interpolation for all channels
for (int c = threadIdx.z; c < channels; c += blockDim.z) {
double topLeft(ys_input_lower_ptr[xsBottom + c]);
double topRight(ys_input_lower_ptr[xsTop + c]);
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
double bottomRight(ys_input_upper_ptr[xsTop + c]);
double top = topLeft + (topRight - topLeft) * xVal;
double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
pZ[x * channels + c] = T(top + (bottom - top) * yVal);
for (int c = 0; c < channels; c++) {
Z topLeft(ys_input_lower_ptr[xsBottom + c]);
Z topRight(ys_input_lower_ptr[xsTop + c]);
Z bottomLeft(ys_input_upper_ptr[xsBottom + c]);
Z bottomRight(ys_input_upper_ptr[xsTop + c]);
Z top = topLeft + (topRight - topLeft) * xVal;
Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
Z resVal = Z(top + (bottom - top) * yVal);
pZ[x * channels + c] = resVal;
}
}
}
@ -105,7 +160,7 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// resize image with
template <typename T>
template <typename T, typename F>
static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
Nd4jLong outWidth, Nd4jLong channels,
BilinearInterpolationData* xs_,
@ -115,12 +170,13 @@ namespace helpers {
Nd4jLong inBatchNumValues = inHeight * inRowSize;
Nd4jLong outRowSize = outWidth * channels;
auto stream = context->getCudaStream();
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
T const* pInput = images->getDataBuffer()->specialAsT<T>(); //reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
F* pOutput = output->dataBuffer()->specialAsT<F>();//reinterpret_cast<F *>(output->specialBuffer());
dim3 batchSizeBlock(batchSize, 1, 1);
dim3 pictureBlock(outHeight, outWidth, channels);
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
resizeImageKernel<T,F><<<256, 256, 256, *stream>>>(pInput, images->getSpecialShapeInfo(), pOutput,
output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize,
inBatchNumValues, xs_, ys_);
auto err = cudaStreamSynchronize(*stream);
if (err != 0) {
@ -129,8 +185,9 @@ namespace helpers {
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
template <typename T, typename F>
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width,
int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
const Nd4jLong batchSize = images->sizeAt(0);
const Nd4jLong inHeight = images->sizeAt(1);
const Nd4jLong inWidth = images->sizeAt(2);
@ -145,19 +202,8 @@ namespace helpers {
return ND4J_STATUS_OK;
}
// Special case for TF compatibility
if((center && inHeight < 2) || (center && inWidth < 2)){
center = false;
}
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
// wrong input data
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
return ND4J_STATUS_BAD_ARGUMENTS;
}
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
BilinearInterpolationData* xs_;// = xs.data();
BilinearInterpolationData* ys_;// = xs.data();
@ -173,12 +219,24 @@ namespace helpers {
}
auto stream = context->getCudaStream();
// Compute the cached interpolation weights on the x and y dimensions.
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
if (halfPixelCenter) {
computeInterpolationWeights <
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
computeInterpolationWeights <
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
}
else {
computeInterpolationWeights <
LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
computeInterpolationWeights <
LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
}
printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth);
NDArray::prepareSpecialUse({output}, {images});
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
resizeImage_<T,F>(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
err = cudaStreamSynchronize(*stream);
NDArray::registerSpecialUse({output}, {images});
err = cudaFree(xs_);
if (err != 0) {
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
@ -197,20 +255,28 @@ namespace helpers {
//
template <typename T>
static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape,
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center) {
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) {
//for (int b = blockIdx.x; b < batchSize; b += gridDim.x)
if (blockIdx.x < batchSize)
{
auto b = blockIdx.x;
for (int y = threadIdx.x; y < outHeight; y += blockDim.x) {
Nd4jLong inY = nd4j::math::nd4j_min(
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
y * heightScale)), inHeight - 1);
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale));
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
if (halfPixelCenters) {
inY = nd4j::math::nd4j_max(0LL, inY);
}
for (int x = threadIdx.y; x < outWidth; x += blockDim.y) {
Nd4jLong inX = nd4j::math::nd4j_min(
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
x * widthScale)), inWidth - 1);
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale));
Nd4jLong inX = nd4j::math::nd4j_min(posX, inWidth - 1);
if (halfPixelCenters) {
inX = nd4j::math::nd4j_max(0LL, inX);
}
auto start = blockIdx.z * blockDim.z + threadIdx.z;
auto step = blockDim.z * gridDim.z;
@ -231,7 +297,8 @@ namespace helpers {
// resizeNeighborFunctor - main algorithm by nearest neighbor
//
template <typename T>
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
const Nd4jLong batchSize = images->sizeAt(0);
const Nd4jLong inHeight = images->sizeAt(1);
const Nd4jLong inWidth = images->sizeAt(2);
@ -246,25 +313,24 @@ namespace helpers {
return ND4J_STATUS_OK;
}
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
// wrong input data
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
return ND4J_STATUS_BAD_ARGUMENTS;
}
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
auto imagesBuffer = reinterpret_cast<T const*>(images->getSpecialBuffer());
auto outputBuffer = reinterpret_cast<T*>(output->specialBuffer());
// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) ||
// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
// // wrong input data
// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
// return ND4J_STATUS_BAD_ARGUMENTS;
// }
// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight));
// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth));
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
auto imagesBuffer = images->getDataBuffer()->specialAsT<T>();//reinterpret_cast<T const*>(images->getSpecialBuffer());
auto outputBuffer = output->dataBuffer()->specialAsT<T>();//reinterpret_cast<T*>(output->specialBuffer());
auto stream = context->getCudaStream();
//T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong* outputShape,
// Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center
//input, inputShape, output, outputShape,
// batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center
NDArray::prepareSpecialUse({output}, {images});
resizeNeighborKernel<T><<<batchSize, outHeight * outWidth, 512, *stream>>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(),
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center);
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters);
NDArray::registerSpecialUse({output}, {images});
return Status::OK();
@ -275,39 +341,38 @@ namespace helpers {
void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight,
Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_,
BilinearInterpolationData* ys_, NDArray* output) {
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(),
resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels,
xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth,
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), LIBND4J_TYPES);
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output),
NUMERIC_TYPES, FLOAT_TYPES);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height,
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images,
width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context,
// NDArray const* images, int const width, int const height, bool const alignCorners,
// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
(context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Bicubic interpolation
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Utility functions and classes
// calculateResizeScale determines the float scaling factor.
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
bool alignCorners) {
return (alignCorners && outSize > 1)
? (inSize - 1) / static_cast<float>(outSize - 1)
: inSize / static_cast<float>(outSize);
}
struct ImageResizerState {
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
: _alignCorners(alignCorners),
@ -362,17 +427,6 @@ namespace helpers {
bool _halfPixelCenters;
};
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
// floating point coordinates of the top,left pixel is 0.5,0.5.
struct HalfPixelScaler {
_CUDA_HD HalfPixelScaler(){};
inline _CUDA_HD float operator()(const int x, const float scale) const {
// Note that we subtract 0.5 from the return value, as the existing bilinear
// sampling code etc assumes pixels are in the old coordinate system.
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
}
};
struct WeightsAndIndices {
float _weight0;
float _weight1;
@ -547,16 +601,6 @@ namespace helpers {
}
}
// Older incorrect scaling method that causes all resizes to have a slight
// translation leading to inconsistent results. For example, a flip then a
// resize gives different results then a resize then a flip.
struct LegacyScaler {
_CUDA_HD LegacyScaler(){};
inline _CUDA_HD float operator()(const int x, const float scale) const {
return static_cast<float>(x) * scale;
}
};
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
auto start = blockIdx.x * blockDim.x + threadIdx.x;
auto step = blockDim.x * gridDim.x;
@ -906,8 +950,8 @@ namespace helpers {
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
switch (method) {
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break;
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
case kResizeLanczos5:
case kResizeGaussian:

View File

@ -30,6 +30,67 @@ namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) {
auto input = reinterpret_cast<T*>(vinput);
auto output = reinterpret_cast<T*>(voutput);
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
// this means that we'll have additional cycle, to move middle element
auto div = numOfElemsToReverse / 2;
auto odd = numOfElemsToReverse % 2 != 0;
auto rlimit = odd ? limit / 2 + 1 : limit / 2;
// all threads operate in the same input/output space
for (uint64_t e = tid; e < rlimit; e += step) {
// finding out the TAD we're going to process
auto tadId = e / div;
if (tadId >= numTads)
continue;
// now finding out element within tad
auto idx = e % div;
//printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx);
auto tadInput = input + inputTadOffsets[tadId];
auto tadOutput = output + outputTadOffsets[tadId];
// we're calculating offsets within input TAD
auto fOffset = shape::getIndexOffset(idx, inputTadShape);
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape);
// now we're storing input values
auto v1 = tadInput[fOffset];
auto v2 = tadInput[lOffset];
// now we're calculating offsets within output TAD
auto zfOffset = shape::getIndexOffset(idx, outputTadShape);
auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape);
// and saving values to output arrays
tadOutput[zfOffset] = v2;
tadOutput[zlOffset] = v1;
}
// moving odd element in blocks
if (odd && threadIdx.x == 0) {
for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) {
auto tadInput = input + inputTadOffsets[e];
auto tadOutput = output + outputTadOffsets[e];
auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape);
auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape);
tadOutput[zOffset] = tadInput[xOffset];
}
}
}
template <typename T>
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -52,7 +113,7 @@ namespace helpers {
auto odd = numOfElemsToReverse % 2 != 0;
auto limit = numOfElemsToReverse / 2;
for (Nd4jLong e = tid; e < limit; e += step) {
for (uint64_t e = tid; e < limit; e += step) {
// we're calculating offsets within input array
auto fOffset = shape::getIndexOffset(e, inputShape);
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
@ -80,13 +141,19 @@ namespace helpers {
}
template<typename T>
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) {
auto stream = context->getCudaStream();
reverseTadKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength);
}
template<typename T>
static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
auto stream = context->getCudaStream();
Nd4jLong numOfReverse = numOfElemsToReverse;
if (numOfElemsToReverse == 0)
numOfReverse = input->lengthOf();
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
}
@ -153,27 +220,23 @@ namespace helpers {
// we need to reverse axis only if that's new op
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis);
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
auto listOut = output->allTensorsAlongDimension(dimensions);
auto listIn = input->allTensorsAlongDimension(dimensions);
NDArray *subArrIn, *subArrOut;
NDArray::prepareSpecialUse({output}, {input});
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
subArrIn = listIn->at(i);
subArrOut = listOut->at(i);
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES);
if (packX.numberOfTads() == 1) {
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES);
} else {
BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES);
}
//BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast<NDArray*>(input), output, (int)0), LIBND4J_TYPES);
NDArray::registerSpecialUse({output}, {input});
delete listOut;
delete listIn;
}
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
}
}

View File

@ -37,15 +37,15 @@ namespace helpers {
kResizeArea
};
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
NDArray* output);
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
NDArray* output);
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool preserveAspectRatio, bool antialias, NDArray* output);
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,

File diff suppressed because one or more lines are too long

View File

@ -212,12 +212,12 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
}
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139};
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f};
Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
NDArray expGWP(_expGradWpB, _expGradWpS);
expGWP.permutei({2,3,1,0});
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747};
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f};
Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
NDArray expGWD(_expGradWdB, _expGradWdS);
expGWD.permutei({2,3,1,0});

View File

@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
auto z = NDArrayFactory::create_<float>('f', {5, 1});
auto expBuffer = new float[5]{28.00,64.00,100.00,136.00,172.00};
auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f};
auto exp = new NDArray(expBuffer, z->getShapeInfo());
nd4j::blas::GEMV<float, float, float>::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1);
@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto result = results->at(0);
// result->printBuffer();
//expected.printIndexedBuffer("E");
//result->printIndexedBuffer("R");
ASSERT_TRUE(expected.isSameShapeStrict(result));
ASSERT_TRUE(expected.equalsTo(result));
@ -3605,7 +3606,9 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {
auto input = NDArrayFactory::create<float>('c', {2,3,4});
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.});
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f,
15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
input.linspace(1);
nd4j::ops::reverse op;

View File

@ -121,10 +121,10 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
}
TEST_F(DeclarableOpsTests10, Test_Not_1) {
auto x = NDArrayFactory::create<bool>('c', {4}, {1, 1, 0, 1});
auto y = NDArrayFactory::create<bool>('c', {4}, {0, 0, 0, 1});
auto x = NDArrayFactory::create<bool>('c', {4}, {true, true, false, true});
auto y = NDArrayFactory::create<bool>('c', {4}, {false, false, false, true});
// auto e = NDArrayFactory::create<bool>('c', {4}, {1, 1, 1, 0});
auto e = NDArrayFactory::create<bool>('c', {4}, {0, 0, 1, 0});
auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false});
nd4j::ops::boolean_not op;
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
@ -245,7 +245,8 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1});
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {true, true, false, false, true, true, true,
true, true, true, false, true, true, true, true});
// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1});
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
@ -623,7 +624,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, range_test12) {
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5});
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});
nd4j::ops::range op;
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
@ -1416,7 +1417,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
@ -1470,6 +1471,138 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
input.assign(0.8f); //linspace(1);
auto size = NDArrayFactory::create<int>({65,65});
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}, {false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
ASSERT_NE(*result, ex);
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
input.assign(0.8f); //linspace(1);
auto size = NDArrayFactory::create<int>({65,65});
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
ASSERT_NE(*result, ex);
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
1., 2., 3., 4.,
2.6, 3.6, 4.6, 5.6,
5., 6., 7., 8.,
7.4, 8.4, 9.4, 10.4,
9., 10., 11., 12.,
4., 5., 6., 7.,
5.6, 6.6, 7.6, 8.6,
8., 9., 10., 11.,
10.4, 11.4, 12.4, 13.4,
12., 13., 14., 15.,
10., 11., 12., 13.,
11.6, 12.6, 13.6, 14.6,
14., 15., 16., 17.,
16.4, 17.4, 18.4, 19.4,
18., 19., 20., 21.,
13., 14., 15., 16.,
14.6, 15.6, 16.6, 17.6,
17., 18., 19., 20.,
19.4, 20.4, 21.4, 22.4,
21., 22., 23., 24.
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printIndexedBuffer("Resized to 4x5 bilinear with half pixels");
//expected.printIndexedBuffer("Expect for 10x10");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
1.f, 2.f, 3.f, 4.f,
2.6f, 3.6f, 4.6f, 5.6f,
5.f, 6.f, 7.f, 8.f,
7.4f, 8.4f, 9.4f, 10.4f,
9.f, 10.f, 11.f, 12.f,
4.f, 5.f, 6.f, 7.f,
5.6f, 6.6f, 7.6f, 8.6f,
8.f, 9.f, 10.f, 11.f,
10.4f, 11.4f, 12.4f, 13.4f,
12.f, 13.f, 14.f, 15.f,
10.f, 11.f, 12.f, 13.f,
11.6f, 12.6f, 13.6f, 14.6f,
14.f, 15.f, 16.f, 17.f,
16.4f, 17.4f, 18.4f, 19.4f,
18.f, 19.f, 20.f, 21.f,
13.f, 14.f, 15.f, 16.f,
14.6f, 15.6f, 16.6f, 17.6f,
17.f, 18.f, 19.f, 20.f,
19.4f, 20.4f, 21.4f, 22.4f,
21.f, 22.f, 23.f, 24.f
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 4x5");
// expected.printBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
@ -1857,7 +1990,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
input.linspace(1);
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input}, {}, {10, 10, 1});
auto results = op.execute({&input}, {}, {10, 10}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -1986,7 +2119,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
input.linspace(1);
nd4j::ops::resize_bilinear op;
auto results = op.execute({&input, &size}, {}, {1});
auto results = op.execute({&input, &size}, {}, {}, {true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -2023,7 +2156,56 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5}, {false, false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printIndexedBuffer("Resized to 4x5");
// expected.printIndexedBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, {
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
@ -2065,47 +2247,48 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
delete results;
}
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
NDArray input = NDArrayFactory::create<float>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
9, 10, 11, 12,
1.f, 2.f, 3.f, 4.f,
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
9.f, 10.f, 11.f, 12.f,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24,
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f,
13, 14, 15, 16,
13, 14, 15, 16,
17, 18, 19, 20,
17, 18, 19, 20,
21, 22, 23, 24
13.f, 14.f, 15.f, 16.f,
13.f, 14.f, 15.f, 16.f,
17.f, 18.f, 19.f, 20.f,
21.f, 22.f, 23.f, 24.f,
21.f, 22.f, 23.f, 24.f
});
//input = 1.f;
input.linspace(1);
nd4j::ops::resize_nearest_neighbor op;
auto results = op.execute({&input}, {}, {4, 5});
auto results = op.execute({&input}, {}, {4,5}, {false, true});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printIndexedBuffer("Resized to 4x5");
// expected.printIndexedBuffer("Expect for 4x5");
// expected.printBuffer("Expect for 4x5");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
@ -2533,7 +2716,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected('c', {1,3,3,1}, {1, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
@ -2557,7 +2740,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
NDArray expected('c', {1,3,3,1}, {1, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
nd4j::ops::crop_and_resize op;
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
@ -2726,7 +2909,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32);
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
@ -2971,22 +3154,6 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
delete results;
}
/* public void testFakeQuantAgainstTF_1() {
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5);
INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5);
INDArray out = Nd4j.createUninitialized(x.shape());
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
assertEquals(expected, out);
}*/
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
NDArray x = NDArrayFactory::create<float>('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
@ -3094,12 +3261,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
@ -3211,19 +3378,19 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20.,
-19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438,
-12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f,
-19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f,
-12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm op;
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -3240,14 +3407,14 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) {
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32);
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 ,
20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 ,
-12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32);
NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f,
20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f,
-12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32);
input.linspace(0.1, 0.1);
nd4j::ops::batchnorm op;
@ -3270,7 +3437,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {
NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32);
NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32);
NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, nd4j::DataType::BOOL);
NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, nd4j::DataType::BOOL);
NDArray result('c', {2,2,2}, nd4j::DataType::BOOL);

View File

@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) {
auto input = NDArrayFactory::create<double>('c', {4, 5});
auto idx = NDArrayFactory::create<Nd4jLong>('c', {4});
auto exp = NDArrayFactory::create<bool>({0, 0, 0, 1});
auto exp = NDArrayFactory::create<bool>({false, false, false, true});
int exclusive, reverse;
input.linspace(1);
@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) {
TEST_F(DeclarableOpsTests12, inTopK_5) {
auto x = NDArrayFactory::create<double>('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} );
auto y = NDArrayFactory::create<Nd4jLong>('f', {6}, {0, 0, 0, 0, 0, 0});
auto expV = NDArrayFactory::create<bool>('f', {6}, {1, 0, 0, 0, 0, 0 });
auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false });
nd4j::ops::in_top_k op;
auto result = op.execute({&x, &y}, {}, {2});

View File

@ -1167,12 +1167,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139,
0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106,
0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f,
0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f,
0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@ -1230,12 +1230,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
Wx({0,1, 0,0, 0,0}) = 0.003f;
Wx({1,2, 0,0, 0,0}) = -0.003f;
Wr({0,1, 0,0, 0,0}) = 0.006f;
Wr({1,2, 0,0, 0,0}) = -0.006f;
b({0,1, 0,0}) = 0.5f;
b({1,2, 0,0}) = -0.5f;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
@ -1245,18 +1245,19 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289,
-0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647,
0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395,
-0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862,
0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308,
-0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128,
0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, 2 * nOut}, {
0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f,
-0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f,
0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f,
-0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f,
0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f,
-0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f,
0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f,
-0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f,
-0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@ -1328,16 +1329,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406,
0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 ,
0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739,
0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681,
0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
NDArray expH('c', {bS, sL, 2*nOut}, {
0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f,
0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f,
0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f,
0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f,
0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492,
-0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609,
-0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f,
-0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f,
-0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@ -1398,12 +1400,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
Wx({0,1, 0,0, 0,0}) = 0.003f;
Wx({1,2, 0,0, 0,0}) = -0.003f;
Wr({0,1, 0,0, 0,0}) = 0.006f;
Wr({1,2, 0,0, 0,0}) = -0.006f;
b({0,1, 0,0}) = 0.5f;
b({1,2, 0,0}) = -0.5f;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
@ -1413,14 +1415,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060,
0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491,
0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {
0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f,
0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f,
0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f,
-0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f},
nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f,
-0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f},
nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
@ -1568,12 +1573,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {
0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f,
0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f,
0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
@ -1650,16 +1656,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843,
0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139,
0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 ,
0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 ,
0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, 2*nOut}, {
0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f,
0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f,
0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f,
0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f,
0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502,
-0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425,
-0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f,
-0.103843f, -0.103843f, -0.103843f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f,
-0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
@ -1731,14 +1738,20 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023,
0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {
0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f,
0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f,
0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
@ -1799,25 +1812,26 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;
Wx = 0.003f;
Wr = 0.006f;
b = 0.5f;
hI = 1.f;
cI = 2.f;
Wp = -0.05f;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209,
0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627,
0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32);
NDArray expH('c', {sL, bS, nOut}, {
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f,
0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f,
0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
@ -1878,18 +1892,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
Wx({0,1, 0,0, 0,0}) = 0.003f;
Wx({1,2, 0,0, 0,0}) = -0.003f;
Wr({0,1, 0,0, 0,0}) = 0.006f;
Wr({1,2, 0,0, 0,0}) = -0.006f;
b({0,1, 0,0}) = 0.5f;
b({1,2, 0,0}) = -0.5f;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
Wp({0,1, 0,0}) = -0.05;
Wp({1,2, 0,0}) = 0.05;
Wp({0,1, 0,0}) = -0.05f;
Wp({1,2, 0,0}) = 0.05f;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
@ -1905,10 +1919,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315,
0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702,
0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f,
0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f,
0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

View File

@ -148,8 +148,8 @@ TEST_F(DeclarableOpsTests15, Test_standarize_1) {
}
TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
auto x = NDArrayFactory::create<float>('c', {5}, {1., 1., 1., 1., 1.});
auto eps = NDArrayFactory::create<float>('c', {5}, {0., 0., 0., 0., 0.});
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
auto eps = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
nd4j::ops::standardize_bp op;
auto result = op.execute({&x, &eps}, {}, {0}, {});

View File

@ -197,3 +197,44 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
delete shapes;
}
TEST_F(DeclarableOpsTests16, test_reverse_1) {
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211};
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635};
for (auto r : rows) {
for (auto c : columns) {
//nd4j_printf("Trying [%i, %i]\n", r, c);
auto array = NDArrayFactory::create<float>('c', {r, c});
auto exp = NDArrayFactory::create<float>('c', {r, c});
auto reversed = NDArrayFactory::create<float>('c', {r, c});
auto rowOriginal = NDArrayFactory::create<float>('c', {c});
auto rowReversed = NDArrayFactory::create<float>('c', {c});
for (int e = 0; e < c; e++) {
rowOriginal.p(e, (float) e);
rowReversed.p(c - e - 1, (float) e);
}
auto listI = array.allTensorsAlongDimension({1});
auto listE = exp.allTensorsAlongDimension({1});
for (int e = 0; e < r; e++) {
listI->at(e)->assign(rowOriginal);
listE->at(e)->assign(rowReversed);
}
delete listI;
delete listE;
nd4j::ops::reverse op;
Nd4jLong axis = 1;
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(exp, reversed);
}
}
}

View File

@ -1591,7 +1591,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) {
auto *result = results->at(0);
ASSERT_TRUE(result->isScalar());
ASSERT_TRUE(result->e<float>(0) == -71.);
ASSERT_TRUE(result->e<float>(0) == -71.f);
delete results;
@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) {
auto *result = results->at(0);
ASSERT_TRUE(result->isScalar());
ASSERT_TRUE(result->e<float>(0) == -69.);
ASSERT_TRUE(result->e<float>(0) == -69.f);
delete results;
@ -1630,8 +1630,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
labels.linspace(1);
weights.assign(0.5);
predictions.assign(0.5);
weights.assign(0.5f);
predictions.assign(0.5f);
nd4j::ops::cosine_distance_loss op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
auto *result = results->at(0);
ASSERT_TRUE(result->isScalar());
ASSERT_TRUE(result->e<float>(0) == -24.);
ASSERT_TRUE(result->e<float>(0) == -24.f);
delete results;
@ -1655,8 +1655,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) {
auto weights = NDArrayFactory::create<float>('c', {1,1});
labels.linspace(1);
weights.assign(0.5);
predictions.assign(0.5);
weights.assign(0.5f);
predictions.assign(0.5f);
nd4j::ops::cosine_distance_loss op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
@ -1680,10 +1680,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
labels.linspace(1);
weights.assign(0.5);
predictions.assign(0.5);
weights.p(0, 0.);
weights.p(1, 0.);
weights.assign(0.5f);
predictions.assign(0.5f);
weights.p(0, 0.f);
weights.p(1, 0.f);
nd4j::ops::cosine_distance_loss op;
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});

View File

@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
b.linspace(10.);
x.assign(1.);
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.,1.,1.,1.,1.,1.,1.});
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {});
@ -2292,9 +2292,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
}
delete results;
@ -2329,9 +2329,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
}
delete results;
@ -2366,9 +2366,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
}
delete results;
@ -2421,9 +2421,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
}
else {
for(uint i = 0; i < expU.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
for(uint i = 0; i < expV.lengthOf(); ++i)
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
}
delete results;

View File

@ -4084,7 +4084,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) {
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f});
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
nd4j::ops::softsign ffOP;
nd4j::ops::softsign_bp bpOp;

View File

@ -24,6 +24,7 @@
#include <NDArray.h>
#include <ops/ops.h>
#include <GradCheck.h>
#include <chrono>
using namespace nd4j;
@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
//ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
/*
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
auto z = x.like();
x.linspace(1.0f);
nd4j::ops::reverse op;
auto timeStart = std::chrono::system_clock::now();
auto status = op.execute({&x}, {&z}, {}, {1}, {});
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
nd4j_printf("exec time: %lld us\n", outerTime);
ASSERT_EQ(Status::OK(), status);
}
*/

View File

@ -661,9 +661,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
NDArray::prepareSpecialUse({&o}, {&x, &y});
@ -685,9 +685,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
TEST_F(JavaInteropTests, Test_Greater_2) {
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
nd4j::ops::greater op;

View File

@ -1163,10 +1163,10 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) {
NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
NDArray exp1('c', {3}, {4., 20., 36.}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,3}, {-10., -2., 6.,14., 22., 30.}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {4}, {38., 41., 44., 47.}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {4}, {114., 117., 120., 123.}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32);
NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2});
@ -1271,8 +1271,10 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) {
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {3,2}, {-88., -124., 6., -2., 22., 14.}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {6,4}, {-36., -44., -52., -60.,-42., -52., -62., -72.,2., 0., -2., -4.,6., 4., 2., 0.,10., 8., 6., 4.,14., 12., 10., 8.}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f,
-4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f},
nd4j::DataType::FLOAT32);
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
@ -1400,10 +1402,10 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32);
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2});
ASSERT_TRUE(z1.equalsTo(&exp1));
@ -1503,7 +1505,7 @@ TEST_F(NDArrayCudaBasicsTests, EqualityTest1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32);
NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32);
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
@ -1511,11 +1513,11 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2}, {21.5,5}, nd4j::DataType::FLOAT32);
NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32);
NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32);
NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32);
NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32);
NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32);
x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2});
ASSERT_TRUE(z1.equalsTo(&exp1));
@ -1575,17 +1577,17 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
NDArray z1('c', {}, {100}, nd4j::DataType::BOOL);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL);
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL);
NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL);
NDArray z1('c', {}, {true}, nd4j::DataType::BOOL);
NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL);
NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL);
NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL);
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL);
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL);
NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL);
NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL);
NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL);
NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL);
x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2});
ASSERT_TRUE(z1.equalsTo(&exp1));
@ -1643,7 +1645,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32);
NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32);
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
@ -1912,7 +1914,7 @@ TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3)
TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
{
double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.};
NDArray a('c', {4,4}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, nd4j::DataType::FLOAT32);
NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, nd4j::DataType::FLOAT32);
auto x = NDArrayFactory::create<double>('c', {3, 2, 1});
auto y = NDArrayFactory::create<double>('c', {1, 2});
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {3, 2, 2});
@ -1928,7 +1930,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
//////////////////////////////////////////////////////////////////////
TEST_F(NDArrayCudaBasicsTests, assign_2)
{
NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, nd4j::DataType::FLOAT32);
NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, nd4j::DataType::FLOAT32);
NDArray y('c', {4}, nd4j::DataType::INT32);
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
@ -1945,30 +1947,30 @@ TEST_F(NDArrayCudaBasicsTests, subarray_1)
NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99};
float buffExpX0[] = {1.000000, 13.000000};
float buffExpX0[] = {1.f, 13.f};
Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99};
float buffExpX1[] = {2.000000, 14.000000};
float buffExpX1[] = {2.f, 14.f};
Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99};
float buffExpX2[] = {1.000000, 13.000000};
float buffExpX2[] = {1.f, 13.f};
Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99};
float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99};
float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99};
float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000};
float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f};
Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99};
float buffExpY0[] = {1.000000, 2.000000};
float buffExpY0[] = {1.f, 2.f};
Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99};
float buffExpY1[] = {7.000000, 8.000000};
float buffExpY1[] = {7.f, 8.f};
Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102};
float buffExpY2[] = {1.000000, 2.000000};
float buffExpY2[] = {1.f, 2.f};
Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99};
float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102};
float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99};
float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000};
float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f};
NDArray x0 = x(0, {1,2});
@ -2121,7 +2123,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) {
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
auto x = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
//x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
x->reshapei('c', {3, 4, 5});
x->permutei({0, 1, 2});
@ -2138,7 +2140,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
auto x = NDArrayFactory::create<float>('c', {1, 60});
x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
x.reshapei('c', {3, 4, 5});
x.permutei({0, 1, 2});
@ -2153,7 +2155,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) {
auto x = NDArrayFactory::create<float>('c', {1, 60});
x.linspace(1);
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
x.reshapei('c', {3, 4, 5});
x.permutei({0, 1, 2});
@ -2170,7 +2172,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) {
auto xx = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
// auto x = *xx;
//x.linspace(1);
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
// x.reshapei('c', {3, 4, 5});
// x.permutei({0, 1, 2});
@ -2188,7 +2190,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) {
//x.linspace(1);
for (int l = 0; l < x.lengthOf(); l++)
x.p(l, float(l + 1.f));
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
x.reshapei('c', {3, 4, 5});
x.permutei({0, 1, 2});

File diff suppressed because one or more lines are too long

View File

@ -103,6 +103,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,

View File

@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
@Override
public String[] tensorflowNames() {
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
}
@Override

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.NoArgsConstructor;
import lombok.val;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
@ -43,20 +44,25 @@ import java.util.Map;
@NoArgsConstructor
public class ResizeBilinear extends DynamicCustomOp {
protected boolean alignCorners = false;
protected boolean halfPixelCenters = false;
protected Integer height = null;
protected Integer width = null;
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width,
boolean alignCorners, boolean halfPixelCenters){
super(sd, input);
this.alignCorners = alignCorners;
this.height = height;
this.width = width;
this.halfPixelCenters = halfPixelCenters;
addArgs();
}
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width,
boolean alignCorners, boolean halfPixelCenters) {
super(new INDArray[]{x}, new INDArray[]{z});
this.alignCorners = alignCorners;
this.halfPixelCenters = halfPixelCenters;
this.height = height;
this.width = width;
addArgs();
@ -76,7 +82,12 @@ public class ResizeBilinear extends DynamicCustomOp {
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
this.alignCorners = attributesForNode.get("align_corners").getB();
val attrC = attributesForNode.get("align_corners");
val attrH = attributesForNode.get("half_pixel_centers");
this.alignCorners = attrC != null ? attrC.getB() : false;
this.halfPixelCenters = attrH != null ? attrH.getB() : false;
addArgs();
}
@ -87,8 +98,7 @@ public class ResizeBilinear extends DynamicCustomOp {
iArguments.add(Long.valueOf(height));
iArguments.add(Long.valueOf(width));
}
iArguments.add(alignCorners ? 1L : 0L);
addBArgument(alignCorners, halfPixelCenters);
}
@Override

View File

@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
if(attributesForNode.containsKey("argmax")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
} else {
outputType = DataType.UINT32;
outputType = DataType.LONG;
}
}
@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
List<DataType> result = new ArrayList<>();
result.add(inputDataTypes.get(0));
result.add(outputType == null ? DataType.UINT32 : outputType);
result.add(outputType == null ? DataType.INT : outputType);
return result;
}
}

View File

@ -4584,6 +4584,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* returns reference on array element with given index
*/
/**
* returns array element with given index
* i - element index in array
@ -5171,6 +5172,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
////////////////////////////////////////////////////////////////////////
@ -5179,6 +5182,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
// #ifndef __JAVACPP_HACK__
// #endif

View File

@ -4587,6 +4587,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
* returns reference on array element with given index
*/
/**
* returns array element with given index
* i - element index in array
@ -5174,6 +5175,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
////////////////////////////////////////////////////////////////////////
@ -5182,6 +5185,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
// #ifndef __JAVACPP_HACK__
// #endif
@ -18280,7 +18285,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
* terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
* Currently the case n = 0 is not supported.
*
* Input arrays:
* 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
@ -18309,6 +18313,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
/**
* This op calculates digamma function psi(x) = derivative of log(Gamma(x))
*
* Input arrays:
* 0: x - abscissa points where to evaluate the digamma function, type float
*
* Output array:
* 0: values of digamma function at corresponding x, type float
*
*/
// #if NOT_EXCLUDED(OP_digamma)
@Namespace("nd4j::ops") public static class digamma extends DeclarableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public digamma(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public digamma(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public digamma position(long position) {
return (digamma)super.position(position);
}
public digamma() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
* Input arrays:
@ -18398,9 +18430,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* This operation adjusts image hue by delta
* Input arrays:
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
* 1 - optional argument, input scalar-array containing delta
*
* T arguments:
* 0 - delta value
* 0 - optional argument, delta value
*
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
@ -18427,9 +18460,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* This operation adjusts image saturation by delta
* Input arrays:
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
* 1 - optional argument, input scalar-array containing saturation factor
*
* T arguments:
* 0 - saturation factor
* 0 - optional argument, saturation factor
*
* Int arguments:
* 0 - optional argument, corresponds to dimension with 3 channels
@ -18456,9 +18490,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
* Input arrays:
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
* 1 - optional argument, input scalar-array containing saturation contrast factor
*
* T arguments:
* 0 - contrast factor
* 0 - optional argument, contrast factor
*
*/
// #if NOT_EXCLUDED(OP_adjust_contrast)

View File

@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation {
.isSameMode(true)
.build();
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig);
assertArrayEquals(inArr.shape(), results[0].eval().shape());
assertArrayEquals(inArr.shape(), results[1].eval().shape());
}
@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable in = sd.var("in", inArr);
SDVariable w = sd.var("w", wArr);
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
INDArray expected = Nd4j.createFromArray(
new double[][][]{

View File

@ -23,13 +23,7 @@ import static org.junit.Assert.fail;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
public class ConvConfigTests {
@ -489,24 +483,24 @@ public class ConvConfigTests {
@Test
public void testConv1D(){
Conv1DConfig.builder().k(2).build();
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
try{
Conv1DConfig.builder().k(0).build();
Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Kernel"));
}
try{
Conv1DConfig.builder().k(4).s(-2).build();
Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Stride"));
}
try{
Conv1DConfig.builder().k(3).p(-2).build();
Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build();
fail();
} catch (IllegalArgumentException e){
assertTrue(e.getMessage().contains("Padding"));

View File

@ -117,9 +117,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
"fake_quant/min_max_args_per_channel.*",
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
"resize_bilinear/int32.*",
// Suggesting TF 1.15 bug
"non_max_suppression_v2/float16.*",

View File

@ -972,7 +972,7 @@ public class CustomOpsTests extends BaseNd4jTest {
INDArray x = Nd4j.rand(1, 2,3,4);
INDArray z = Nd4j.createUninitialized(x.shape());
boolean align = false;
val op = new ResizeBilinear(x, z, 10, 10, align);
val op = new ResizeBilinear(x, z, 10, 10, align, false);
Nd4j.exec(op);
}
@ -1174,6 +1174,7 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, x);
}
@Ignore("AS failed 2019/12/04")
@Test
public void testPolygamma() {
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);