Update master (#8511)
* cleaned up bert iterator tests (#110) Signed-off-by: eraly <susan.eraly@gmail.com> * Various pre-release fixes (#111) * Various fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix default dtypes for MaxPoolWithArgmax Signed-off-by: AlexDBlack <blacka101@gmail.com> * Small pre-release tweak (#112) * Log UI address on launch as in previous Play-based UI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Logging level tweak for UI Signed-off-by: AlexDBlack <blacka101@gmail.com> * http not https Signed-off-by: AlexDBlack <blacka101@gmail.com> * datavec python ensure host (#113) * ensure host * one more host ensure * info->debug * [WIP] reverse improvements (#115) * initial commit Signed-off-by: raver119 <raver119@gmail.com> * reverse draft Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com> * reverse kernel Signed-off-by: raver119 <raver119@gmail.com> * 2 micro fixes Signed-off-by: raver119 <raver119@gmail.com> * Shugeo resize fix5 (#102) * Refactored resize images ops to use TF-like bool args as input. * Refactored helpers for cpu implementation of resize_bilinear and resize_nearest_neighbor ops. * Refactored cuda implementation for image.resize_bilinear and image.resize_nearest_neighbor ops helpers. * Refactored nearest_neighbor resize op. * Added a pair of tests for special case of resize_bilinear algorithm. * Fixed issue with resize_bilinear op. * Refactored cpu implementation for helpers with resize_nearest_neighbor op. * Final fixed for resize ops to conform TF v.1.5 * Refactored cuda helpers for resize_neares_neighbor op. * Fixed resize_bilinear to accept proper data. * Fixed issue with non-float input for resize_bilinear op. * Refactored cuda helper for resize_bilinear to proper process non-float inputs. * Added tests for resize_bilinear to int inputs. * Fixed ResizeBilinear wrapper * Tests fixed * Fixed float and bool constant to avoid overflow for some kind of compilers. * Corrected float constants with float data type. * Added f suffix for float constants. * Corrected float constant to avoid overflow with initializing lists. * Corrected float initializing list with float input. * Corrected bool constant with initalizing list. * Corrected float and bool values with initializing lists. * Fixed wrong constant. * Fixed issue with 1x1 input picture for resize. * ResizeBilinear default values on import fix Signed-off-by: raver119 <raver119@gmail.com>master
parent
a6223d307b
commit
972fae60dc
|
@ -21,6 +21,7 @@ import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.concurrency.AffinityManager;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -60,6 +61,7 @@ public class NumpyArray {
|
||||||
setND4JArray();
|
setND4JArray();
|
||||||
if (copy){
|
if (copy){
|
||||||
nd4jArray = nd4jArray.dup();
|
nd4jArray = nd4jArray.dup();
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
this.address = nd4jArray.data().address();
|
this.address = nd4jArray.data().address();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -85,6 +87,7 @@ public class NumpyArray {
|
||||||
setND4JArray();
|
setND4JArray();
|
||||||
if (copy){
|
if (copy){
|
||||||
nd4jArray = nd4jArray.dup();
|
nd4jArray = nd4jArray.dup();
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
this.address = nd4jArray.data().address();
|
this.address = nd4jArray.data().address();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,11 +107,12 @@ public class NumpyArray {
|
||||||
nd4jStrides[i] = strides[i] / elemSize;
|
nd4jStrides[i] = strides[i] / elemSize;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
|
nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype);
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
}
|
}
|
||||||
|
|
||||||
public NumpyArray(INDArray nd4jArray){
|
public NumpyArray(INDArray nd4jArray){
|
||||||
|
Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
|
||||||
DataBuffer buff = nd4jArray.data();
|
DataBuffer buff = nd4jArray.data();
|
||||||
address = buff.pointer().address();
|
address = buff.pointer().address();
|
||||||
shape = nd4jArray.shape();
|
shape = nd4jArray.shape();
|
||||||
|
|
|
@ -605,7 +605,7 @@ public class PythonExecutioner {
|
||||||
|
|
||||||
|
|
||||||
private static synchronized void _exec(String code) {
|
private static synchronized void _exec(String code) {
|
||||||
log.info(code);
|
log.debug(code);
|
||||||
log.info("CPython: PyRun_SimpleStringFlag()");
|
log.info("CPython: PyRun_SimpleStringFlag()");
|
||||||
|
|
||||||
int result = PyRun_SimpleStringFlags(code, null);
|
int result = PyRun_SimpleStringFlags(code, null);
|
||||||
|
|
|
@ -17,11 +17,13 @@
|
||||||
|
|
||||||
package org.deeplearning4j.iterator;
|
package org.deeplearning4j.iterator;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
|
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
|
||||||
|
import org.deeplearning4j.iterator.provider.CollectionLabeledPairSentenceProvider;
|
||||||
|
import org.deeplearning4j.iterator.provider.CollectionLabeledSentenceProvider;
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.BertWordPieceTokenizerFactory;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.base.Preconditions;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||||
|
@ -42,8 +44,12 @@ import static org.junit.Assert.*;
|
||||||
|
|
||||||
public class TestBertIterator extends BaseDL4JTest {
|
public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
private File pathToVocab = Resources.asFile("other/vocab.txt");
|
private static File pathToVocab = Resources.asFile("other/vocab.txt");
|
||||||
private static Charset c = StandardCharsets.UTF_8;
|
private static Charset c = StandardCharsets.UTF_8;
|
||||||
|
private static String shortSentence = "I saw a girl with a telescope.";
|
||||||
|
private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
||||||
|
private static String sentenceA = "Goodnight noises everywhere";
|
||||||
|
private static String sentenceB = "Goodnight moon";
|
||||||
|
|
||||||
public TestBertIterator() throws IOException {
|
public TestBertIterator() throws IOException {
|
||||||
}
|
}
|
||||||
|
@ -51,20 +57,15 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
public void testBertSequenceClassification() throws Exception {
|
public void testBertSequenceClassification() throws Exception {
|
||||||
|
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
int minibatchSize = 2;
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
TestSentenceHelper testHelper = new TestSentenceHelper();
|
||||||
List<String> forInference = new ArrayList<>();
|
|
||||||
forInference.add(toTokenize1);
|
|
||||||
forInference.add(toTokenize2);
|
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
|
||||||
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -73,82 +74,77 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
System.out.println(mds.getFeatures(0));
|
System.out.println(mds.getFeatures(0));
|
||||||
System.out.println(mds.getFeaturesMaskArray(0));
|
System.out.println(mds.getFeaturesMaskArray(0));
|
||||||
|
|
||||||
|
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
|
||||||
List<String> tokens = t.create(toTokenize1).getTokens();
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
Map<String, Integer> m = t.getVocab();
|
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
for (int i = 0; i < tokens.size(); i++) {
|
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
int idx = m.get(tokens.get(i));
|
List<String> tokens = testHelper.getTokenizedSentences().get(i);
|
||||||
expEx0.putScalar(0, i, idx);
|
System.out.println(tokens);
|
||||||
expM0.putScalar(0, i, 1);
|
for (int j = 0; j < tokens.size(); j++) {
|
||||||
}
|
String token = tokens.get(j);
|
||||||
|
|
||||||
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
List<String> tokens2 = t.create(toTokenize2).getTokens();
|
|
||||||
for (int i = 0; i < tokens2.size(); i++) {
|
|
||||||
String token = tokens2.get(i);
|
|
||||||
if (!m.containsKey(token)) {
|
if (!m.containsKey(token)) {
|
||||||
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
||||||
}
|
}
|
||||||
int idx = m.get(token);
|
int idx = m.get(token);
|
||||||
expEx1.putScalar(0, i, idx);
|
expFTemp.putScalar(0, j, idx);
|
||||||
expM1.putScalar(0, i, 1);
|
expMTemp.putScalar(0, j, 1);
|
||||||
|
}
|
||||||
|
if (i == 0) {
|
||||||
|
expF = expFTemp.dup();
|
||||||
|
expM = expMTemp.dup();
|
||||||
|
} else {
|
||||||
|
expF = Nd4j.vstack(expF, expFTemp);
|
||||||
|
expM = Nd4j.vstack(expM, expMTemp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray expF = Nd4j.vstack(expEx0, expEx1);
|
|
||||||
INDArray expM = Nd4j.vstack(expM0, expM1);
|
|
||||||
|
|
||||||
assertEquals(expF, mds.getFeatures(0));
|
assertEquals(expF, mds.getFeatures(0));
|
||||||
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
assertEquals(expM, mds.getFeaturesMaskArray(0));
|
||||||
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
|
assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
|
||||||
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
|
assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
|
||||||
|
|
||||||
b.next(); //pop the third element
|
|
||||||
assertFalse(b.hasNext());
|
assertFalse(b.hasNext());
|
||||||
b.reset();
|
b.reset();
|
||||||
assertTrue(b.hasNext());
|
assertTrue(b.hasNext());
|
||||||
|
|
||||||
forInference.set(0, toTokenize2);
|
|
||||||
//Same thing, but with segment ID also
|
//Same thing, but with segment ID also
|
||||||
b = BertIterator.builder()
|
b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.build();
|
.build();
|
||||||
mds = b.next();
|
mds = b.next();
|
||||||
assertEquals(2, mds.getFeatures().length);
|
assertEquals(2, mds.getFeatures().length);
|
||||||
//assertEquals(2, mds.getFeaturesMaskArrays().length); second element is null...
|
|
||||||
assertEquals(2, b.featurizeSentences(forInference).getFirst().length);
|
|
||||||
//Segment ID should be all 0s for single segment task
|
//Segment ID should be all 0s for single segment task
|
||||||
INDArray segmentId = expM.like();
|
INDArray segmentId = expM.like();
|
||||||
assertEquals(segmentId, mds.getFeatures(1));
|
assertEquals(segmentId, mds.getFeatures(1));
|
||||||
assertEquals(segmentId, b.featurizeSentences(forInference).getFirst()[1]);
|
assertEquals(segmentId, b.featurizeSentences(testHelper.getSentences()).getFirst()[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
public void testBertUnsupervised() throws Exception {
|
public void testBertUnsupervised() throws Exception {
|
||||||
|
int minibatchSize = 2;
|
||||||
|
TestSentenceHelper testHelper = new TestSentenceHelper();
|
||||||
//Task 1: Unsupervised
|
//Task 1: Unsupervised
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.UNSUPERVISED)
|
.task(BertIterator.Task.UNSUPERVISED)
|
||||||
.masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
|
.masker(new BertMaskedLMMasker(new Random(12345), 0.2, 0.5, 0.5))
|
||||||
.unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
|
.unsupervisedLabelFormat(BertIterator.UnsupervisedLabelFormat.RANK2_IDX)
|
||||||
.maskToken("[MASK]")
|
.maskToken("[MASK]")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
System.out.println("Mask token index: " + t.getVocab().get("[MASK]"));
|
System.out.println("Mask token index: " + testHelper.getTokenizer().getVocab().get("[MASK]"));
|
||||||
|
|
||||||
MultiDataSet mds = b.next();
|
MultiDataSet mds = b.next();
|
||||||
System.out.println(mds.getFeatures(0));
|
System.out.println(mds.getFeatures(0));
|
||||||
|
@ -156,7 +152,6 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
System.out.println(mds.getLabels(0));
|
System.out.println(mds.getLabels(0));
|
||||||
System.out.println(mds.getLabelsMaskArray(0));
|
System.out.println(mds.getLabelsMaskArray(0));
|
||||||
|
|
||||||
b.next(); //pop the third element
|
|
||||||
assertFalse(b.hasNext());
|
assertFalse(b.hasNext());
|
||||||
b.reset();
|
b.reset();
|
||||||
assertTrue(b.hasNext());
|
assertTrue(b.hasNext());
|
||||||
|
@ -164,39 +159,33 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
public void testLengthHandling() throws Exception {
|
public void testLengthHandling() throws Exception {
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
int minibatchSize = 2;
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
TestSentenceHelper testHelper = new TestSentenceHelper();
|
||||||
List<String> forInference = new ArrayList<>();
|
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
|
||||||
forInference.add(toTokenize1);
|
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
|
||||||
forInference.add(toTokenize2);
|
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
List<String> tokens = testHelper.getTokenizedSentences().get(i);
|
||||||
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
List<String> tokens = t.create(toTokenize1).getTokens();
|
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
System.out.println(tokens);
|
System.out.println(tokens);
|
||||||
Map<String, Integer> m = t.getVocab();
|
for (int j = 0; j < tokens.size(); j++) {
|
||||||
for (int i = 0; i < tokens.size(); i++) {
|
String token = tokens.get(j);
|
||||||
int idx = m.get(tokens.get(i));
|
|
||||||
expEx0.putScalar(0, i, idx);
|
|
||||||
expM0.putScalar(0, i, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
List<String> tokens2 = t.create(toTokenize2).getTokens();
|
|
||||||
System.out.println(tokens2);
|
|
||||||
for (int i = 0; i < tokens2.size(); i++) {
|
|
||||||
String token = tokens2.get(i);
|
|
||||||
if (!m.containsKey(token)) {
|
if (!m.containsKey(token)) {
|
||||||
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
||||||
}
|
}
|
||||||
int idx = m.get(token);
|
int idx = m.get(token);
|
||||||
expEx1.putScalar(0, i, idx);
|
expFTemp.putScalar(0, j, idx);
|
||||||
expM1.putScalar(0, i, 1);
|
expMTemp.putScalar(0, j, 1);
|
||||||
|
}
|
||||||
|
if (i == 0) {
|
||||||
|
expF = expFTemp.dup();
|
||||||
|
expM = expMTemp.dup();
|
||||||
|
} else {
|
||||||
|
expF = Nd4j.vstack(expF, expFTemp);
|
||||||
|
expM = Nd4j.vstack(expM, expMTemp);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray expF = Nd4j.vstack(expEx0, expEx1);
|
|
||||||
INDArray expM = Nd4j.vstack(expM0, expM1);
|
|
||||||
|
|
||||||
//--------------------------------------------------------------
|
//--------------------------------------------------------------
|
||||||
|
|
||||||
|
@ -205,12 +194,12 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
//Any length: as long as we need to fit longest sequence
|
//Any length: as long as we need to fit longest sequence
|
||||||
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
|
.lengthHandling(BertIterator.LengthHandling.ANY_LENGTH, -1)
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.build();
|
.build();
|
||||||
MultiDataSet mds = b.next();
|
MultiDataSet mds = b.next();
|
||||||
|
@ -219,20 +208,19 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
||||||
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
|
assertEquals(expF.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeatures(0));
|
||||||
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
|
assertEquals(expM.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 14)), mds.getFeaturesMaskArray(0));
|
||||||
assertEquals(mds.getFeatures(0), b.featurizeSentences(forInference).getFirst()[0]);
|
assertEquals(mds.getFeatures(0), b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
|
||||||
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(forInference).getSecond()[0]);
|
assertEquals(mds.getFeaturesMaskArray(0), b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
|
||||||
|
|
||||||
//Clip only: clip to maximum, but don't pad if less
|
//Clip only: clip to maximum, but don't pad if less
|
||||||
b = BertIterator.builder()
|
b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
|
.lengthHandling(BertIterator.LengthHandling.CLIP_ONLY, 20)
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.build();
|
.build();
|
||||||
mds = b.next();
|
|
||||||
expShape = new long[]{2, 14};
|
expShape = new long[]{2, 14};
|
||||||
assertArrayEquals(expShape, mds.getFeatures(0).shape());
|
assertArrayEquals(expShape, mds.getFeatures(0).shape());
|
||||||
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
assertArrayEquals(expShape, mds.getFeaturesMaskArray(0).shape());
|
||||||
|
@ -241,54 +229,38 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
@Test(timeout = 20000L)
|
@Test(timeout = 20000L)
|
||||||
public void testMinibatchPadding() throws Exception {
|
public void testMinibatchPadding() throws Exception {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
String toTokenize1 = "I saw a girl with a telescope.";
|
int minibatchSize = 3;
|
||||||
String toTokenize2 = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
TestSentenceHelper testHelper = new TestSentenceHelper(minibatchSize);
|
||||||
String toTokenize3 = "Goodnight noises everywhere";
|
|
||||||
List<String> forInference = new ArrayList<>();
|
|
||||||
forInference.add(toTokenize1);
|
|
||||||
forInference.add(toTokenize2);
|
|
||||||
forInference.add(toTokenize3);
|
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
|
||||||
INDArray expEx0 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
INDArray expM0 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
List<String> tokens = t.create(toTokenize1).getTokens();
|
|
||||||
Map<String, Integer> m = t.getVocab();
|
|
||||||
for (int i = 0; i < tokens.size(); i++) {
|
|
||||||
int idx = m.get(tokens.get(i));
|
|
||||||
expEx0.putScalar(0, i, idx);
|
|
||||||
expM0.putScalar(0, i, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray expEx1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
INDArray expM1 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
List<String> tokens2 = t.create(toTokenize2).getTokens();
|
|
||||||
for (int i = 0; i < tokens2.size(); i++) {
|
|
||||||
String token = tokens2.get(i);
|
|
||||||
if (!m.containsKey(token)) {
|
|
||||||
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
|
||||||
}
|
|
||||||
int idx = m.get(token);
|
|
||||||
expEx1.putScalar(0, i, idx);
|
|
||||||
expM1.putScalar(0, i, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray expEx3 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
INDArray expM3 = Nd4j.create(DataType.INT, 1, 16);
|
|
||||||
List<String> tokens3 = t.create(toTokenize3).getTokens();
|
|
||||||
for (int i = 0; i < tokens3.size(); i++) {
|
|
||||||
String token = tokens3.get(i);
|
|
||||||
if (!m.containsKey(token)) {
|
|
||||||
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
|
||||||
}
|
|
||||||
int idx = m.get(token);
|
|
||||||
expEx3.putScalar(0, i, idx);
|
|
||||||
expM3.putScalar(0, i, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
|
INDArray zeros = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expF = Nd4j.vstack(expEx0, expEx1, expEx3, zeros);
|
INDArray expF = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expM = Nd4j.vstack(expM0, expM1, expM3, zeros);
|
INDArray expM = Nd4j.create(DataType.INT, 1, 16);
|
||||||
INDArray expL = Nd4j.createFromArray(new float[][]{{1, 0}, {0, 1}, {1, 0}, {0, 0}});
|
Map<String, Integer> m = testHelper.getTokenizer().getVocab();
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
List<String> tokens = testHelper.getTokenizedSentences().get(i);
|
||||||
|
INDArray expFTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
|
INDArray expMTemp = Nd4j.create(DataType.INT, 1, 16);
|
||||||
|
System.out.println(tokens);
|
||||||
|
for (int j = 0; j < tokens.size(); j++) {
|
||||||
|
String token = tokens.get(j);
|
||||||
|
if (!m.containsKey(token)) {
|
||||||
|
throw new IllegalStateException("Unknown token: \"" + token + "\"");
|
||||||
|
}
|
||||||
|
int idx = m.get(token);
|
||||||
|
expFTemp.putScalar(0, j, idx);
|
||||||
|
expMTemp.putScalar(0, j, 1);
|
||||||
|
}
|
||||||
|
if (i == 0) {
|
||||||
|
expF = expFTemp.dup();
|
||||||
|
expM = expMTemp.dup();
|
||||||
|
} else {
|
||||||
|
expF = Nd4j.vstack(expF.dup(), expFTemp);
|
||||||
|
expM = Nd4j.vstack(expM.dup(), expMTemp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expF = Nd4j.vstack(expF, zeros);
|
||||||
|
expM = Nd4j.vstack(expM, zeros);
|
||||||
|
INDArray expL = Nd4j.createFromArray(new float[][]{{0, 1}, {1, 0}, {0, 1}, {0, 0}});
|
||||||
INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
|
INDArray expLM = Nd4j.create(DataType.FLOAT, 4, 1);
|
||||||
expLM.putScalar(0, 0, 1);
|
expLM.putScalar(0, 0, 1);
|
||||||
expLM.putScalar(1, 0, 1);
|
expLM.putScalar(1, 0, 1);
|
||||||
|
@ -297,13 +269,13 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
//--------------------------------------------------------------
|
//--------------------------------------------------------------
|
||||||
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testHelper.getTokenizer())
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 16)
|
||||||
.minibatchSize(4)
|
.minibatchSize(minibatchSize + 1)
|
||||||
.padMinibatches(true)
|
.padMinibatches(true)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(testHelper.getSentenceProvider())
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -323,170 +295,175 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
assertEquals(expL, mds.getLabels(0));
|
assertEquals(expL, mds.getLabels(0));
|
||||||
assertEquals(expLM, mds.getLabelsMaskArray(0));
|
assertEquals(expLM, mds.getLabelsMaskArray(0));
|
||||||
|
|
||||||
assertEquals(expF, b.featurizeSentences(forInference).getFirst()[0]);
|
assertEquals(expF, b.featurizeSentences(testHelper.getSentences()).getFirst()[0]);
|
||||||
assertEquals(expM, b.featurizeSentences(forInference).getSecond()[0]);
|
assertEquals(expM, b.featurizeSentences(testHelper.getSentences()).getSecond()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Checks that a mds from a pair sentence is equal to hstack'd mds from the left side and right side of the pair
|
||||||
|
Checks different lengths for max length to check popping and padding
|
||||||
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairsSingle() throws IOException {
|
public void testSentencePairsSingle() throws IOException {
|
||||||
String shortSent = "I saw a girl with a telescope.";
|
|
||||||
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
|
||||||
boolean prependAppend;
|
boolean prependAppend;
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
int numOfSentences;
|
||||||
int shortL = t.create(shortSent).countTokens();
|
|
||||||
int longL = t.create(longSent).countTokens();
|
TestSentenceHelper testHelper = new TestSentenceHelper();
|
||||||
|
int shortL = testHelper.getShortestL();
|
||||||
|
int longL = testHelper.getLongestL();
|
||||||
|
|
||||||
Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
|
Triple<MultiDataSet, MultiDataSet, MultiDataSet> multiDataSetTriple;
|
||||||
MultiDataSet shortLongPair, shortSentence, longSentence;
|
MultiDataSet fromPair, leftSide, rightSide;
|
||||||
|
|
||||||
// check for pair max length exactly equal to sum of lengths - pop neither no padding
|
// check for pair max length exactly equal to sum of lengths - pop neither no padding
|
||||||
// should be the same as hstack with segment ids 1 for second sentence
|
// should be the same as hstack with segment ids 1 for second sentence
|
||||||
prependAppend = true;
|
prependAppend = true;
|
||||||
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend);
|
numOfSentences = 1;
|
||||||
shortLongPair = multiDataSetTriple.getFirst();
|
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL, shortL, longL), prependAppend, numOfSentences);
|
||||||
shortSentence = multiDataSetTriple.getSecond();
|
fromPair = multiDataSetTriple.getFirst();
|
||||||
longSentence = multiDataSetTriple.getThird();
|
leftSide = multiDataSetTriple.getSecond();
|
||||||
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
|
rightSide = multiDataSetTriple.getThird();
|
||||||
longSentence.getFeatures(1).addi(1);
|
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
|
||||||
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
|
rightSide.getFeatures(1).addi(1); //add 1 for right side segment ids
|
||||||
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
|
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
|
||||||
|
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
|
||||||
|
|
||||||
//check for pair max length greater than sum of lengths - pop neither with padding
|
//check for pair max length greater than sum of lengths - pop neither with padding
|
||||||
// features should be the same as hstack of shorter and longer padded with prepend/append
|
// features should be the same as hstack of shorter and longer padded with prepend/append
|
||||||
// segment id should 1 only in the longer for part of the length of the sentence
|
// segment id should 1 only in the longer for part of the length of the sentence
|
||||||
prependAppend = true;
|
prependAppend = true;
|
||||||
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend);
|
numOfSentences = 1;
|
||||||
shortLongPair = multiDataSetTriple.getFirst();
|
multiDataSetTriple = generateMultiDataSets(new Triple<>(shortL + longL + 5, shortL, longL + 5), prependAppend, numOfSentences);
|
||||||
shortSentence = multiDataSetTriple.getSecond();
|
fromPair = multiDataSetTriple.getFirst();
|
||||||
longSentence = multiDataSetTriple.getThird();
|
leftSide = multiDataSetTriple.getSecond();
|
||||||
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
|
rightSide = multiDataSetTriple.getThird();
|
||||||
longSentence.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
|
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
|
||||||
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
|
rightSide.getFeatures(1).get(NDArrayIndex.all(), NDArrayIndex.interval(0, longL + 1)).addi(1); //segmentId stays 0 for the padded part
|
||||||
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
|
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
|
||||||
|
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
|
||||||
|
|
||||||
//check for pair max length less than shorter sentence - pop both
|
//check for pair max length less than shorter sentence - pop both
|
||||||
//should be the same as hstack with segment ids 1 for second sentence if no prepend/append
|
//should be the same as hstack with segment ids 1 for second sentence if no prepend/append
|
||||||
int maxL = shortL - 2;
|
int maxL = 5;//checking odd
|
||||||
|
numOfSentences = 3;
|
||||||
prependAppend = false;
|
prependAppend = false;
|
||||||
multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend);
|
multiDataSetTriple = generateMultiDataSets(new Triple<>(maxL, maxL / 2, maxL - maxL / 2), prependAppend, numOfSentences);
|
||||||
shortLongPair = multiDataSetTriple.getFirst();
|
fromPair = multiDataSetTriple.getFirst();
|
||||||
shortSentence = multiDataSetTriple.getSecond();
|
leftSide = multiDataSetTriple.getSecond();
|
||||||
longSentence = multiDataSetTriple.getThird();
|
rightSide = multiDataSetTriple.getThird();
|
||||||
assertEquals(shortLongPair.getFeatures(0), Nd4j.hstack(shortSentence.getFeatures(0), longSentence.getFeatures(0)));
|
assertEquals(fromPair.getFeatures(0), Nd4j.hstack(leftSide.getFeatures(0), rightSide.getFeatures(0)));
|
||||||
longSentence.getFeatures(1).addi(1);
|
rightSide.getFeatures(1).addi(1);
|
||||||
assertEquals(shortLongPair.getFeatures(1), Nd4j.hstack(shortSentence.getFeatures(1), longSentence.getFeatures(1)));
|
assertEquals(fromPair.getFeatures(1), Nd4j.hstack(leftSide.getFeatures(1), rightSide.getFeatures(1)));
|
||||||
assertEquals(shortLongPair.getFeaturesMaskArray(0), Nd4j.hstack(shortSentence.getFeaturesMaskArray(0), longSentence.getFeaturesMaskArray(0)));
|
assertEquals(fromPair.getFeaturesMaskArray(0), Nd4j.hstack(leftSide.getFeaturesMaskArray(0), rightSide.getFeaturesMaskArray(0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Same idea as previous test - construct mds from bert iterator with sep sentences and check against one with pairs
|
||||||
|
Checks various max lengths
|
||||||
|
Has sentences of varying lengths
|
||||||
|
*/
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairsUnequalLengths() throws IOException {
|
public void testSentencePairsUnequalLengths() throws IOException {
|
||||||
//check for pop only longer (i.e between longer and longer + shorter), first row pop from second sentence, next row pop from first sentence, nothing to pop in the third row
|
|
||||||
//should be identical to hstack if there is no append, prepend
|
int minibatchSize = 4;
|
||||||
//batch size is 2
|
int numOfSentencesinIter = 3;
|
||||||
int mbS = 4;
|
|
||||||
String shortSent = "I saw a girl with a telescope.";
|
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(numOfSentencesinIter);
|
||||||
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
int shortL = testPairHelper.getShortL();
|
||||||
String sent1 = "Goodnight noises everywhere"; //shorter than shortSent - no popping
|
int longL = testPairHelper.getLongL();
|
||||||
String sent2 = "Goodnight moon"; //shorter than shortSent - no popping
|
int sent1L = testPairHelper.getSentenceALen();
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
int sent2L = testPairHelper.getSentenceBLen();
|
||||||
int shortL = t.create(shortSent).countTokens();
|
|
||||||
int longL = t.create(longSent).countTokens();
|
System.out.println("Sentence Pairs, Left");
|
||||||
int sent1L = t.create(sent1).countTokens();
|
System.out.println(testPairHelper.getSentencesLeft());
|
||||||
int sent2L = t.create(sent2).countTokens();
|
System.out.println("Sentence Pairs, Right");
|
||||||
//won't check 2*shortL + 1 because this will always pop on the left
|
System.out.println(testPairHelper.getSentencesRight());
|
||||||
for (int maxL = longL + shortL - 1; maxL > 2 * shortL; maxL--) {
|
|
||||||
|
//anything outside this range more will need to check padding,truncation
|
||||||
|
for (int maxL = longL + shortL; maxL > 2 * shortL + 1; maxL--) {
|
||||||
|
|
||||||
|
System.out.println("Running for max length = " + maxL);
|
||||||
|
|
||||||
MultiDataSet leftMDS = BertIterator.builder()
|
MultiDataSet leftMDS = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testPairHelper.getTokenizer())
|
||||||
.minibatchSize(mbS)
|
.minibatchSize(minibatchSize)
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testPairHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(new TestSentenceHelper(numOfSentencesinIter).getSentenceProvider())
|
||||||
.padMinibatches(true)
|
.padMinibatches(true)
|
||||||
.build().next();
|
.build().next();
|
||||||
|
|
||||||
MultiDataSet rightMDS = BertIterator.builder()
|
MultiDataSet rightMDS = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testPairHelper.getTokenizer())
|
||||||
.minibatchSize(mbS)
|
.minibatchSize(minibatchSize)
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testPairHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL + 10) //random big num guaranteed to be longer than either
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, longL * 10) //random big num guaranteed to be longer than either
|
||||||
.sentenceProvider(new TestSentenceProvider(true))
|
.sentenceProvider(new TestSentenceHelper(true, numOfSentencesinIter).getSentenceProvider())
|
||||||
.padMinibatches(true)
|
.padMinibatches(true)
|
||||||
.build().next();
|
.build().next();
|
||||||
|
|
||||||
MultiDataSet pairMDS = BertIterator.builder()
|
MultiDataSet pairMDS = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testPairHelper.getTokenizer())
|
||||||
.minibatchSize(mbS)
|
.minibatchSize(minibatchSize)
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testPairHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL) //random big num guaranteed to be longer than either
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxL)
|
||||||
.sentencePairProvider(new TestSentencePairProvider())
|
.sentencePairProvider(testPairHelper.getPairSentenceProvider())
|
||||||
.padMinibatches(true)
|
.padMinibatches(true)
|
||||||
.build().next();
|
.build().next();
|
||||||
|
|
||||||
//Left sentences here are {{shortSent},
|
|
||||||
// {longSent},
|
|
||||||
// {Sent1}}
|
|
||||||
//Right sentences here are {{longSent},
|
|
||||||
// {shortSent},
|
|
||||||
// {Sent2}}
|
|
||||||
//The sentence pairs here are {{shortSent,longSent},
|
|
||||||
// {longSent,shortSent}
|
|
||||||
// {Sent1, Sent2}}
|
|
||||||
|
|
||||||
//CHECK FEATURES
|
//CHECK FEATURES
|
||||||
INDArray combinedFeat = Nd4j.create(DataType.INT,mbS,maxL);
|
INDArray combinedFeat = Nd4j.create(DataType.INT, minibatchSize, maxL);
|
||||||
//left side
|
//left side
|
||||||
INDArray leftFeatures = leftMDS.getFeatures(0);
|
INDArray leftFeatures = leftMDS.getFeatures(0);
|
||||||
INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
|
INDArray topLSentFeat = leftFeatures.getRow(0).get(NDArrayIndex.interval(0, shortL));
|
||||||
INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
|
INDArray midLSentFeat = leftFeatures.getRow(1).get(NDArrayIndex.interval(0, maxL - shortL));
|
||||||
INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0,sent1L));
|
INDArray bottomLSentFeat = leftFeatures.getRow(2).get(NDArrayIndex.interval(0, sent1L));
|
||||||
//right side
|
//right side
|
||||||
INDArray rightFeatures = rightMDS.getFeatures(0);
|
INDArray rightFeatures = rightMDS.getFeatures(0);
|
||||||
INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
|
INDArray topRSentFeat = rightFeatures.getRow(0).get(NDArrayIndex.interval(0, maxL - shortL));
|
||||||
INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
|
INDArray midRSentFeat = rightFeatures.getRow(1).get(NDArrayIndex.interval(0, shortL));
|
||||||
INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0,sent2L));
|
INDArray bottomRSentFeat = rightFeatures.getRow(2).get(NDArrayIndex.interval(0, sent2L));
|
||||||
//expected pair
|
//expected pair
|
||||||
combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat,topRSentFeat));
|
combinedFeat.getRow(0).addi(Nd4j.hstack(topLSentFeat, topRSentFeat));
|
||||||
combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat,midRSentFeat));
|
combinedFeat.getRow(1).addi(Nd4j.hstack(midLSentFeat, midRSentFeat));
|
||||||
combinedFeat.getRow(2).get(NDArrayIndex.interval(0,sent1L+sent2L)).addi(Nd4j.hstack(bottomLSentFeat,bottomRSentFeat));
|
combinedFeat.getRow(2).get(NDArrayIndex.interval(0, sent1L + sent2L)).addi(Nd4j.hstack(bottomLSentFeat, bottomRSentFeat));
|
||||||
|
|
||||||
assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
|
assertEquals(maxL, pairMDS.getFeatures(0).shape()[1]);
|
||||||
assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
|
assertArrayEquals(combinedFeat.shape(), pairMDS.getFeatures(0).shape());
|
||||||
assertEquals(combinedFeat, pairMDS.getFeatures(0));
|
assertEquals(combinedFeat, pairMDS.getFeatures(0));
|
||||||
|
|
||||||
//CHECK SEGMENT ID
|
//CHECK SEGMENT ID
|
||||||
INDArray combinedFetSeg = Nd4j.create(DataType.INT, mbS, maxL);
|
INDArray combinedFetSeg = Nd4j.create(DataType.INT, minibatchSize, maxL);
|
||||||
combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
|
combinedFetSeg.get(NDArrayIndex.point(0), NDArrayIndex.interval(shortL, maxL)).addi(1);
|
||||||
combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
|
combinedFetSeg.get(NDArrayIndex.point(1), NDArrayIndex.interval(maxL - shortL, maxL)).addi(1);
|
||||||
combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L+sent2L)).addi(1);
|
combinedFetSeg.get(NDArrayIndex.point(2), NDArrayIndex.interval(sent1L, sent1L + sent2L)).addi(1);
|
||||||
assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
|
assertArrayEquals(combinedFetSeg.shape(), pairMDS.getFeatures(1).shape());
|
||||||
assertEquals(maxL, combinedFetSeg.shape()[1]);
|
assertEquals(maxL, combinedFetSeg.shape()[1]);
|
||||||
assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
|
assertEquals(combinedFetSeg, pairMDS.getFeatures(1));
|
||||||
|
|
||||||
|
testPairHelper.getPairSentenceProvider().reset();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSentencePairFeaturizer() throws IOException {
|
public void testSentencePairFeaturizer() throws IOException {
|
||||||
String shortSent = "I saw a girl with a telescope.";
|
int minibatchSize = 2;
|
||||||
String longSent = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum";
|
TestSentencePairsHelper testPairHelper = new TestSentencePairsHelper(minibatchSize);
|
||||||
List<Pair<String, String>> listSentencePair = new ArrayList<>();
|
|
||||||
listSentencePair.add(new Pair<>(shortSent, longSent));
|
|
||||||
listSentencePair.add(new Pair<>(longSent, shortSent));
|
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
|
||||||
BertIterator b = BertIterator.builder()
|
BertIterator b = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(testPairHelper.getTokenizer())
|
||||||
.minibatchSize(2)
|
.minibatchSize(minibatchSize)
|
||||||
.padMinibatches(true)
|
.padMinibatches(true)
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(testPairHelper.getTokenizer().getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
.task(BertIterator.Task.SEQ_CLASSIFICATION)
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, 128)
|
||||||
.sentencePairProvider(new TestSentencePairProvider())
|
.sentencePairProvider(testPairHelper.getPairSentenceProvider())
|
||||||
.prependToken("[CLS]")
|
.prependToken("[CLS]")
|
||||||
.appendToken("[SEP]")
|
.appendToken("[SEP]")
|
||||||
.build();
|
.build();
|
||||||
|
@ -494,23 +471,19 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
INDArray[] featuresArr = mds.getFeatures();
|
INDArray[] featuresArr = mds.getFeatures();
|
||||||
INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
|
INDArray[] featuresMaskArr = mds.getFeaturesMaskArrays();
|
||||||
|
|
||||||
Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(listSentencePair);
|
Pair<INDArray[], INDArray[]> p = b.featurizeSentencePairs(testPairHelper.getSentencePairs());
|
||||||
assertEquals(p.getFirst().length, 2);
|
assertEquals(p.getFirst().length, 2);
|
||||||
assertEquals(featuresArr[0], p.getFirst()[0]);
|
assertEquals(featuresArr[0], p.getFirst()[0]);
|
||||||
assertEquals(featuresArr[1], p.getFirst()[1]);
|
assertEquals(featuresArr[1], p.getFirst()[1]);
|
||||||
//assertEquals(p.getSecond().length, 2);
|
|
||||||
assertEquals(featuresMaskArr[0], p.getSecond()[0]);
|
assertEquals(featuresMaskArr[0], p.getSecond()[0]);
|
||||||
//assertEquals(featuresMaskArr[1], p.getSecond()[1]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns three multidatasets from bert iterator based on given max lengths and whether to prepend/append
|
* Returns three multidatasets (one from pair of sentences and the other two from single sentence lists) from bert iterator
|
||||||
|
* with given max lengths and whether to prepend/append
|
||||||
* Idea is the sentence pair dataset can be constructed from the single sentence datasets
|
* Idea is the sentence pair dataset can be constructed from the single sentence datasets
|
||||||
* First one is constructed from a sentence pair "I saw a girl with a telescope." & "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
|
|
||||||
* Second one is constructed from the left of the sentence pair i.e "I saw a girl with a telescope."
|
|
||||||
* Third one is constructed from the right of the sentence pair i.e "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"
|
|
||||||
*/
|
*/
|
||||||
private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend) throws IOException {
|
private Triple<MultiDataSet, MultiDataSet, MultiDataSet> generateMultiDataSets(Triple<Integer, Integer, Integer> maxLengths, boolean prependAppend, int numSentences) throws IOException {
|
||||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
int maxforPair = maxLengths.getFirst();
|
int maxforPair = maxLengths.getFirst();
|
||||||
int maxPartOne = maxLengths.getSecond();
|
int maxPartOne = maxLengths.getSecond();
|
||||||
|
@ -518,133 +491,155 @@ public class TestBertIterator extends BaseDL4JTest {
|
||||||
BertIterator.Builder commonBuilder;
|
BertIterator.Builder commonBuilder;
|
||||||
commonBuilder = BertIterator.builder()
|
commonBuilder = BertIterator.builder()
|
||||||
.tokenizer(t)
|
.tokenizer(t)
|
||||||
.minibatchSize(1)
|
.minibatchSize(4)
|
||||||
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
.featureArrays(BertIterator.FeatureArrays.INDICES_MASK_SEGMENTID)
|
||||||
.vocabMap(t.getVocab())
|
.vocabMap(t.getVocab())
|
||||||
.task(BertIterator.Task.SEQ_CLASSIFICATION);
|
.task(BertIterator.Task.SEQ_CLASSIFICATION);
|
||||||
BertIterator shortLongPairFirstIter = commonBuilder
|
BertIterator pairIter = commonBuilder
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxforPair + 3 : maxforPair)
|
||||||
.sentencePairProvider(new TestSentencePairProvider())
|
.sentencePairProvider(new TestSentencePairsHelper(numSentences).getPairSentenceProvider())
|
||||||
.prependToken(prependAppend ? "[CLS]" : null)
|
.prependToken(prependAppend ? "[CLS]" : null)
|
||||||
.appendToken(prependAppend ? "[SEP]" : null)
|
.appendToken(prependAppend ? "[SEP]" : null)
|
||||||
.build();
|
.build();
|
||||||
BertIterator shortFirstIter = commonBuilder
|
BertIterator leftIter = commonBuilder
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartOne + 2 : maxPartOne)
|
||||||
.sentenceProvider(new TestSentenceProvider())
|
.sentenceProvider(new TestSentenceHelper(numSentences).getSentenceProvider())
|
||||||
.prependToken(prependAppend ? "[CLS]" : null)
|
.prependToken(prependAppend ? "[CLS]" : null)
|
||||||
.appendToken(prependAppend ? "[SEP]" : null)
|
.appendToken(prependAppend ? "[SEP]" : null)
|
||||||
.build();
|
.build();
|
||||||
BertIterator longFirstIter = commonBuilder
|
BertIterator rightIter = commonBuilder
|
||||||
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
|
.lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, prependAppend ? maxPartTwo + 1 : maxPartTwo)
|
||||||
.sentenceProvider(new TestSentenceProvider(true))
|
.sentenceProvider(new TestSentenceHelper(true, numSentences).getSentenceProvider())
|
||||||
.prependToken(null)
|
.prependToken(null)
|
||||||
.appendToken(prependAppend ? "[SEP]" : null)
|
.appendToken(prependAppend ? "[SEP]" : null)
|
||||||
.build();
|
.build();
|
||||||
return new Triple<>(shortLongPairFirstIter.next(), shortFirstIter.next(), longFirstIter.next());
|
return new Triple<>(pairIter.next(), leftIter.next(), rightIter.next());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class TestSentenceProvider implements LabeledSentenceProvider {
|
@Getter
|
||||||
|
private static class TestSentencePairsHelper {
|
||||||
|
|
||||||
private int pos = 0;
|
private List<String> sentencesLeft;
|
||||||
private boolean invert;
|
private List<String> sentencesRight;
|
||||||
|
private List<Pair<String, String>> sentencePairs;
|
||||||
|
private List<List<String>> tokenizedSentencesLeft;
|
||||||
|
private List<List<String>> tokenizedSentencesRight;
|
||||||
|
private List<String> labels;
|
||||||
|
private int shortL;
|
||||||
|
private int longL;
|
||||||
|
private int sentenceALen;
|
||||||
|
private int sentenceBLen;
|
||||||
|
private BertWordPieceTokenizerFactory tokenizer;
|
||||||
|
private CollectionLabeledPairSentenceProvider pairSentenceProvider;
|
||||||
|
|
||||||
private TestSentenceProvider() {
|
private TestSentencePairsHelper() throws IOException {
|
||||||
this.invert = false;
|
this(3);
|
||||||
}
|
}
|
||||||
|
|
||||||
private TestSentenceProvider(boolean invert) {
|
private TestSentencePairsHelper(int minibatchSize) throws IOException {
|
||||||
this.invert = invert;
|
sentencesLeft = new ArrayList<>();
|
||||||
|
sentencesRight = new ArrayList<>();
|
||||||
|
sentencePairs = new ArrayList<>();
|
||||||
|
labels = new ArrayList<>();
|
||||||
|
tokenizedSentencesLeft = new ArrayList<>();
|
||||||
|
tokenizedSentencesRight = new ArrayList<>();
|
||||||
|
tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
|
sentencesLeft.add(shortSentence);
|
||||||
|
sentencesRight.add(longSentence);
|
||||||
|
sentencePairs.add(new Pair<>(shortSentence, longSentence));
|
||||||
|
labels.add("positive");
|
||||||
|
if (minibatchSize > 1) {
|
||||||
|
sentencesLeft.add(longSentence);
|
||||||
|
sentencesRight.add(shortSentence);
|
||||||
|
sentencePairs.add(new Pair<>(longSentence, shortSentence));
|
||||||
|
labels.add("negative");
|
||||||
|
if (minibatchSize > 2) {
|
||||||
|
sentencesLeft.add(sentenceA);
|
||||||
|
sentencesRight.add(sentenceB);
|
||||||
|
sentencePairs.add(new Pair<>(sentenceA, sentenceB));
|
||||||
|
labels.add("positive");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < minibatchSize; i++) {
|
||||||
|
List<String> tokensL = tokenizer.create(sentencesLeft.get(i)).getTokens();
|
||||||
|
List<String> tokensR = tokenizer.create(sentencesRight.get(i)).getTokens();
|
||||||
|
if (i == 0) {
|
||||||
|
shortL = tokensL.size();
|
||||||
|
longL = tokensR.size();
|
||||||
|
}
|
||||||
|
if (i == 2) {
|
||||||
|
sentenceALen = tokensL.size();
|
||||||
|
sentenceBLen = tokensR.size();
|
||||||
|
}
|
||||||
|
tokenizedSentencesLeft.add(tokensL);
|
||||||
|
tokenizedSentencesRight.add(tokensR);
|
||||||
|
}
|
||||||
|
pairSentenceProvider = new CollectionLabeledPairSentenceProvider(sentencesLeft, sentencesRight, labels, null);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Getter
|
||||||
public boolean hasNext() {
|
private static class TestSentenceHelper {
|
||||||
return pos < totalNumSentences();
|
|
||||||
|
private List<String> sentences;
|
||||||
|
private List<List<String>> tokenizedSentences;
|
||||||
|
private List<String> labels;
|
||||||
|
private int shortestL = 0;
|
||||||
|
private int longestL = 0;
|
||||||
|
private BertWordPieceTokenizerFactory tokenizer;
|
||||||
|
private CollectionLabeledSentenceProvider sentenceProvider;
|
||||||
|
|
||||||
|
private TestSentenceHelper() throws IOException {
|
||||||
|
this(false, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
private TestSentenceHelper(int minibatchSize) throws IOException {
|
||||||
public Pair<String, String> nextSentence() {
|
this(false, minibatchSize);
|
||||||
Preconditions.checkState(hasNext());
|
}
|
||||||
if (pos == 0) {
|
|
||||||
pos++;
|
private TestSentenceHelper(boolean alternateOrder) throws IOException {
|
||||||
if (!invert) return new Pair<>("I saw a girl with a telescope.", "positive");
|
this(false, 3);
|
||||||
return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
|
}
|
||||||
|
|
||||||
|
private TestSentenceHelper(boolean alternateOrder, int minibatchSize) throws IOException {
|
||||||
|
sentences = new ArrayList<>();
|
||||||
|
labels = new ArrayList<>();
|
||||||
|
tokenizedSentences = new ArrayList<>();
|
||||||
|
tokenizer = new BertWordPieceTokenizerFactory(pathToVocab, false, false, c);
|
||||||
|
if (!alternateOrder) {
|
||||||
|
sentences.add(shortSentence);
|
||||||
|
labels.add("positive");
|
||||||
|
if (minibatchSize > 1) {
|
||||||
|
sentences.add(longSentence);
|
||||||
|
labels.add("negative");
|
||||||
|
if (minibatchSize > 2) {
|
||||||
|
sentences.add(sentenceA);
|
||||||
|
labels.add("positive");
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (pos == 1) {
|
sentences.add(longSentence);
|
||||||
pos++;
|
labels.add("negative");
|
||||||
if (!invert) return new Pair<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "negative");
|
if (minibatchSize > 1) {
|
||||||
return new Pair<>("I saw a girl with a telescope.", "positive");
|
sentences.add(shortSentence);
|
||||||
}
|
labels.add("positive");
|
||||||
pos++;
|
if (minibatchSize > 2) {
|
||||||
if (!invert)
|
sentences.add(sentenceB);
|
||||||
return new Pair<>("Goodnight noises everywhere", "positive");
|
labels.add("positive");
|
||||||
return new Pair<>("Goodnight moon", "positive");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
pos = 0;
|
|
||||||
}
|
}
|
||||||
|
for (int i = 0; i < sentences.size(); i++) {
|
||||||
@Override
|
List<String> tokenizedSentence = tokenizer.create(sentences.get(i)).getTokens();
|
||||||
public int totalNumSentences() {
|
if (i == 0)
|
||||||
return 3;
|
shortestL = tokenizedSentence.size();
|
||||||
|
if (tokenizedSentence.size() > longestL)
|
||||||
|
longestL = tokenizedSentence.size();
|
||||||
|
if (tokenizedSentence.size() < shortestL)
|
||||||
|
shortestL = tokenizedSentence.size();
|
||||||
|
tokenizedSentences.add(tokenizedSentence);
|
||||||
}
|
}
|
||||||
|
sentenceProvider = new CollectionLabeledSentenceProvider(sentences, labels, null);
|
||||||
@Override
|
|
||||||
public List<String> allLabels() {
|
|
||||||
return Arrays.asList("positive", "negative");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numLabelClasses() {
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class TestSentencePairProvider implements LabeledPairSentenceProvider {
|
|
||||||
|
|
||||||
private int pos = 0;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasNext() {
|
|
||||||
return pos < totalNumSentences();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Triple<String, String, String> nextSentencePair() {
|
|
||||||
Preconditions.checkState(hasNext());
|
|
||||||
if (pos == 0) {
|
|
||||||
pos++;
|
|
||||||
return new Triple<>("I saw a girl with a telescope.", "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "positive");
|
|
||||||
} else {
|
|
||||||
if (pos == 1) {
|
|
||||||
pos++;
|
|
||||||
return new Triple<>("Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum", "I saw a girl with a telescope.", "negative");
|
|
||||||
}
|
|
||||||
pos++;
|
|
||||||
return new Triple<>("Goodnight noises everywhere", "Goodnight moon", "positive");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void reset() {
|
|
||||||
pos = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int totalNumSentences() {
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<String> allLabels() {
|
|
||||||
return Arrays.asList("positive", "negative");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int numLabelClasses() {
|
|
||||||
return 2;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -254,6 +254,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
|
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
|
||||||
uiEventRoutingThread.setDaemon(true);
|
uiEventRoutingThread.setDaemon(true);
|
||||||
uiEventRoutingThread.start();
|
uiEventRoutingThread.start();
|
||||||
|
|
||||||
|
String address = UIServer.getInstance().getAddress();
|
||||||
|
log.info("Deeplearning4j UI server started at: {}", address);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
|
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
|
||||||
|
@ -317,7 +320,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getAddress() {
|
public String getAddress() {
|
||||||
return "https://localhost:" + server.actualPort();
|
return "http://localhost:" + server.actualPort();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -421,7 +424,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
private void runHelper() throws Exception {
|
private void runHelper() throws Exception {
|
||||||
log.info("VertxUIServer.StatsEventRouterRunnable started");
|
log.trace("VertxUIServer.StatsEventRouterRunnable started");
|
||||||
//Idea: collect all event stats, and route them to the appropriate modules
|
//Idea: collect all event stats, and route them to the appropriate modules
|
||||||
while (!shutdown.get()) {
|
while (!shutdown.get()) {
|
||||||
|
|
||||||
|
|
|
@ -1256,6 +1256,9 @@ namespace nd4j {
|
||||||
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j);
|
||||||
template<typename T>
|
template<typename T>
|
||||||
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k);
|
||||||
|
template<typename T>
|
||||||
|
FORCEINLINE T& t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w);
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
|
@ -1268,6 +1271,8 @@ namespace nd4j {
|
||||||
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j) const;
|
||||||
template<typename T>
|
template<typename T>
|
||||||
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const;
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const;
|
||||||
|
template<typename T>
|
||||||
|
FORCEINLINE T t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1711,7 +1716,7 @@ namespace nd4j {
|
||||||
if (isEmpty())
|
if (isEmpty())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
return shape::isMatrix(this->_shapeInfo);
|
return 0 != shape::isMatrix(this->_shapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1751,7 +1756,7 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::isScalar() const {
|
bool NDArray::isScalar() const {
|
||||||
return shape::isScalar(this->_shapeInfo);
|
return 0 != shape::isScalar(this->_shapeInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2082,7 +2087,7 @@ template <typename T>
|
||||||
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
||||||
|
|
||||||
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
|
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
|
||||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
@ -2095,6 +2100,23 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
|
||||||
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) {
|
||||||
|
|
||||||
|
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2), w >= sizeAt(3))
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4 !");
|
||||||
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
if(!isActualOnHostSide())
|
||||||
|
syncToHost();
|
||||||
|
|
||||||
|
Nd4jLong coords[4] = {i, j, k, w};
|
||||||
|
auto offset = shape::getOffset(getShapeInfo(), coords);
|
||||||
|
tickWriteHost();
|
||||||
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T NDArray::t(const Nd4jLong i) const {
|
T NDArray::t(const Nd4jLong i) const {
|
||||||
|
@ -2133,7 +2155,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
|
||||||
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
|
||||||
|
|
||||||
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
if (rankOf() != 3 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2))
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=2 !");
|
throw std::invalid_argument("NDArray::t(i,j,k): one of input indexes is out of array length or rank!=3!");
|
||||||
if (DataTypeUtils::fromT<T>() != _dataType)
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
throw std::invalid_argument("NDArray::t(i,j,k): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
@ -2146,6 +2168,23 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
|
||||||
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong w) const {
|
||||||
|
|
||||||
|
if (rankOf() != 4 || i >= sizeAt(0) || j >= sizeAt(1) || k >= sizeAt(2) || w >= sizeAt(3))
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): one of input indexes is out of array length or rank!=4!");
|
||||||
|
if (DataTypeUtils::fromT<T>() != _dataType)
|
||||||
|
throw std::invalid_argument("NDArray::t(i,j,k,w): type of array is not equal to template type T!");
|
||||||
|
|
||||||
|
if(!isActualOnHostSide())
|
||||||
|
syncToHost();
|
||||||
|
|
||||||
|
Nd4jLong coords[4] = {i, j, k, w};
|
||||||
|
auto offset = shape::getOffset(getShapeInfo(), coords);
|
||||||
|
tickReadHost();
|
||||||
|
return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
|
||||||
|
}
|
||||||
|
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::shared_ptr<DataBuffer> NDArray::getDataBuffer() const {
|
std::shared_ptr<DataBuffer> NDArray::getDataBuffer() const {
|
||||||
|
|
|
@ -2348,7 +2348,7 @@ NDArray NDArray::operator-(const NDArray& other) const {
|
||||||
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&result}, {this, &other});
|
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Subtract, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({&result}, {this, &other});
|
NDArray::registerSpecialUse({&result}, {this, &other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
@ -2394,7 +2394,7 @@ NDArray NDArray::operator/(const NDArray& other) const {
|
||||||
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
NDArray result(getShapeInfo(), DataTypeUtils::pickPairwiseResultType(getShapeInfo(), other.getShapeInfo()), false, getContext());
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&result}, {this, &other});
|
NDArray::prepareSpecialUse({&result}, {this, &other});
|
||||||
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execPairwiseTransform(getContext(), nd4j::pairwise::Divide, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), result.buffer(), result.getShapeInfo(), result.specialBuffer(), result.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({&result}, {this, &other});
|
NDArray::registerSpecialUse({&result}, {this, &other});
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -46,7 +46,7 @@ namespace nd4j {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setAllowedOutputTypes(0, DataType::INHERIT)
|
->setAllowedOutputTypes(0, DataType::INHERIT)
|
||||||
->setAllowedOutputTypes(1, DataType::INT64);
|
->setAllowedOutputTypes(1, {ALL_INTS});
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,8 @@ namespace nd4j {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
auto inRank = image->rankOf();
|
auto inRank = image->rankOf();
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
|
|
||||||
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
|
REQUIRE_TRUE(inRank == 3 || inRank == 4, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
|
||||||
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
|
REQUIRE_TRUE(output->rankOf() == inRank, 0, "resize_bicubic: Source tensor and output should have the same rank, but %i and %i given.", inRank, output->rankOf());
|
||||||
REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
|
REQUIRE_TRUE(size->rankOf() == 1, size->lengthOf() == 2, 0, "resize_bicubic: Resize params is a pair of values, not %i.", size->lengthOf());
|
||||||
|
@ -57,7 +59,7 @@ namespace nd4j {
|
||||||
if (block.numB()> 1)
|
if (block.numB()> 1)
|
||||||
halfPixelAlign = block.getBArguments()->at(1);
|
halfPixelAlign = block.getBArguments()->at(1);
|
||||||
}
|
}
|
||||||
REQUIRE_TRUE(halfPixelAlign == false || halfPixelAlign == true && alignCorners == false, 0, "resize_bicubic: half pixel align can be used only with non-aligned corners");
|
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
|
||||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||||
|
|
|
@ -32,8 +32,10 @@ namespace nd4j {
|
||||||
NDArray* output = OUTPUT_VARIABLE(0);
|
NDArray* output = OUTPUT_VARIABLE(0);
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
bool center = false; // - default value
|
bool alignCorners = false; // - default value
|
||||||
auto inRank = image->rankOf();
|
auto inRank = image->rankOf();
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
|
|
||||||
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
|
REQUIRE_TRUE( inRank == 4 || inRank == 3, 0, "resize_bilinear: input image should be 4D "
|
||||||
"tensor, but input has rank %i",
|
"tensor, but input has rank %i",
|
||||||
image->rankOf());
|
image->rankOf());
|
||||||
|
@ -46,21 +48,25 @@ namespace nd4j {
|
||||||
auto newImageSize = INPUT_VARIABLE(1);
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_bilinear: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_bilinear: Resize params already given by the second param. Int params are expensive.");
|
||||||
width = newImageSize->e<int>(0);
|
height = newImageSize->e<int>(0);
|
||||||
height = newImageSize->e<int>(1);
|
width = newImageSize->e<int>(1);
|
||||||
if (block.numI() == 1) {
|
|
||||||
center = 0 != INT_ARG(0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() > 1, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
height = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
width = INT_ARG(1);
|
||||||
if (block.numI() == 3)
|
|
||||||
center = 0 != INT_ARG(2);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
if (block.numB() > 0)
|
||||||
|
alignCorners = B_ARG(0);
|
||||||
|
bool halfPixelCenter = false;
|
||||||
|
|
||||||
|
if (block.numB() > 1)
|
||||||
|
halfPixelCenter = B_ARG(1);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_bilinear: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
|
||||||
|
return helpers::resizeBilinearFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(resize_bilinear) {
|
DECLARE_SHAPE_FN(resize_bilinear) {
|
||||||
|
@ -83,7 +89,7 @@ namespace nd4j {
|
||||||
height = newImageSize->e<int>(1);
|
height = newImageSize->e<int>(1);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() == 2, 0, "resize_bilinear: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
width = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
height = INT_ARG(1);
|
||||||
}
|
}
|
||||||
|
@ -101,7 +107,12 @@ namespace nd4j {
|
||||||
outputShape[2] = height;
|
outputShape[2] = height;
|
||||||
outputShape[3] = in[3];
|
outputShape[3] = in[3];
|
||||||
}
|
}
|
||||||
|
if (DataTypeUtils::isR(ArrayOptions::dataType(in))) {
|
||||||
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ShapeUtils::updateStridesAndType(outputShape, DataType::FLOAT32, shape::order(in));
|
||||||
|
}
|
||||||
|
|
||||||
shapeList->push_back(CONSTANT(outputShape));
|
shapeList->push_back(CONSTANT(outputShape));
|
||||||
return shapeList;
|
return shapeList;
|
||||||
|
|
|
@ -31,35 +31,40 @@ namespace nd4j {
|
||||||
|
|
||||||
auto image = INPUT_VARIABLE(0);
|
auto image = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
auto inRank = image->rankOf();
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
bool center = false; // - default value
|
bool alignCorners = false; // - default value
|
||||||
|
if (output->isEmpty()) return Status::OK();
|
||||||
if (block.width() > 1) {
|
if (block.width() > 1) {
|
||||||
auto newImageSize = INPUT_VARIABLE(1);
|
auto newImageSize = INPUT_VARIABLE(1);
|
||||||
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
REQUIRE_TRUE(newImageSize->lengthOf() == 2, 0, "resize_nearest_neighbor: Resize params is a pair of values, not %i.", newImageSize->lengthOf());
|
||||||
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
|
REQUIRE_TRUE(block.numI() <= 1, 0, "resize_nearest_neighbor: Resize params already given by the second param. Int params are expensive.");
|
||||||
width = newImageSize->e<int>(0);
|
height = newImageSize->e<int>(0);
|
||||||
height = newImageSize->e<int>(1);
|
width = newImageSize->e<int>(1);
|
||||||
if (block.numI() == 1) {
|
|
||||||
center = 0 != INT_ARG(0);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
REQUIRE_TRUE(block.numI() <= 3, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
|
REQUIRE_TRUE(block.numI() == 2, 0, "resize_nearest_neighbor: Neither resize width nor height are provided.");
|
||||||
width = INT_ARG(0);
|
height = INT_ARG(0);
|
||||||
height = INT_ARG(1);
|
width = INT_ARG(1);
|
||||||
if (block.numI() == 3)
|
|
||||||
center = 0 != INT_ARG(2);
|
|
||||||
}
|
}
|
||||||
auto inRank = image->rankOf();
|
if (block.numB() > 0)
|
||||||
|
alignCorners = B_ARG(0);
|
||||||
|
bool halfPixelCenter = false;
|
||||||
|
|
||||||
|
if (block.numB() > 1)
|
||||||
|
halfPixelCenter = B_ARG(1);
|
||||||
|
REQUIRE_TRUE(width <= (1 << 24) || height <= (1 << 24), 0, "resize_nearest_neighbour: the image resize should be limited to 2^24 pixels both for height and width, but %d and %d were given.", height, width);
|
||||||
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_nearest_neighbor: Input should be 4D tensor, but rank %i occured");
|
||||||
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_nearest_neighbor: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
||||||
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
|
REQUIRE_TRUE(image->dataType() == output->dataType(), 0, "resize_nearest_neighbor: Input and output types should be the same, but `%s' occured instead.", DataTypeUtils::asString(output->dataType()).c_str());
|
||||||
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
REQUIRE_TRUE(!halfPixelCenter || (halfPixelCenter && !alignCorners), 0, "resize_nearest_neighbor: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||||
|
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
|
||||||
|
|
||||||
|
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||||
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||||
|
|
||||||
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, center, inRank == 4?output:&target);
|
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
|
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
|
||||||
|
|
|
@ -120,6 +120,27 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
||||||
|
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
||||||
|
struct HalfPixelScalerNN {
|
||||||
|
HalfPixelScalerNN(){};
|
||||||
|
inline float operator()(const int x, const float scale) const {
|
||||||
|
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
||||||
|
// sampling code etc assumes pixels are in the old coordinate system.
|
||||||
|
return (static_cast<float>(x) + 0.5f) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Older incorrect scaling method that causes all resizes to have a slight
|
||||||
|
// translation leading to inconsistent results. For example, a flip then a
|
||||||
|
// resize gives different results then a resize then a flip.
|
||||||
|
struct LegacyScaler {
|
||||||
|
LegacyScaler(){};
|
||||||
|
inline float operator()(const int x, const float scale) const {
|
||||||
|
return static_cast<float>(x) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct WeightsAndIndices {
|
struct WeightsAndIndices {
|
||||||
float _weight0;
|
float _weight0;
|
||||||
float _weight1;
|
float _weight1;
|
||||||
|
@ -133,7 +154,8 @@ namespace helpers {
|
||||||
int _advance; // advance value.
|
int _advance; // advance value.
|
||||||
};
|
};
|
||||||
|
|
||||||
inline void computeInterpolationWeights(Nd4jLong outSize,
|
template <class Scaler>
|
||||||
|
inline void computeInterpolationWeights(const Scaler scaler, Nd4jLong outSize,
|
||||||
Nd4jLong inSize,
|
Nd4jLong inSize,
|
||||||
double scale,
|
double scale,
|
||||||
BilinearInterpolationData *interpolationData) {
|
BilinearInterpolationData *interpolationData) {
|
||||||
|
@ -143,10 +165,12 @@ namespace helpers {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (auto k = start; k < stop; k++) {
|
for (auto k = start; k < stop; k++) {
|
||||||
auto i = (outSize - k - 1);
|
auto i = (outSize - k - 1);
|
||||||
double in = i * scale;
|
double const in = scaler(i, scale);
|
||||||
interpolationData[i]._bottomIndex = static_cast<Nd4jLong>(in);
|
double const in_f = nd4j::math::nd4j_floor<double, double>(in);
|
||||||
interpolationData[i]._topIndex = nd4j::math::nd4j_min(interpolationData[i]._bottomIndex + 1, inSize - 1);
|
double const in_c = nd4j::math::nd4j_ceil<double, double>(in);
|
||||||
interpolationData[i]._interpolarValue = in - interpolationData[i]._bottomIndex;
|
interpolationData[i]._bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
|
||||||
|
interpolationData[i]._topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
|
||||||
|
interpolationData[i]._interpolarValue = in - in_f;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, outSize);
|
samediff::Threads::parallel_for(func, 0, outSize);
|
||||||
|
@ -156,29 +180,29 @@ namespace helpers {
|
||||||
* Computes the bilinear interpolation from the appropriate 4 float points
|
* Computes the bilinear interpolation from the appropriate 4 float points
|
||||||
* and the linear interpolation weights.
|
* and the linear interpolation weights.
|
||||||
*/
|
*/
|
||||||
static void
|
// static void
|
||||||
resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
// resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
// Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const& xs,
|
// std::vector<BilinearInterpolationData> const& xs,
|
||||||
std::vector<BilinearInterpolationData> const& ys,
|
// std::vector<BilinearInterpolationData> const& ys,
|
||||||
NDArray *output);
|
// NDArray *output);
|
||||||
|
|
||||||
template<typename T>
|
template<typename T, typename Z>
|
||||||
static void
|
static void
|
||||||
resizeImage_(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
resizeImage_(T const* pInputBuf, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const &xs,
|
std::vector<BilinearInterpolationData> const &xs,
|
||||||
std::vector<BilinearInterpolationData> const &ys,
|
std::vector<BilinearInterpolationData> const &ys,
|
||||||
NDArray *output) {
|
Z* pOutputBuf) {
|
||||||
|
|
||||||
Nd4jLong inRowSize = inWidth * channels;
|
Nd4jLong inRowSize = inWidth * channels;
|
||||||
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
||||||
Nd4jLong outRowSize = outWidth * channels;
|
Nd4jLong outRowSize = outWidth * channels;
|
||||||
|
|
||||||
T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
// T const *pInputBuf = images->getDataBuffer()->primaryAsT<T>(); // this works only with 'c' direction
|
||||||
BilinearInterpolationData const* xsPtr = xs.data();
|
BilinearInterpolationData const* xsPtr = xs.data();
|
||||||
|
|
||||||
T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
|
// T* pOutputBuf = output->dataBuffer()->primaryAsT<T>();
|
||||||
auto computeBilinear = [](double topLeft, double topRight,
|
auto computeBilinear = [](double topLeft, double topRight,
|
||||||
double bottomLeft, double bottomRight,
|
double bottomLeft, double bottomRight,
|
||||||
double xVal, double yVal) {
|
double xVal, double yVal) {
|
||||||
|
@ -214,8 +238,12 @@ namespace helpers {
|
||||||
samediff::Threads::parallel_tad(func, 0, batchSize);
|
samediff::Threads::parallel_tad(func, 0, batchSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename X, typename Z>
|
||||||
static int resizeBilinearFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
|
static int resizeBilinearFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners,
|
||||||
|
bool const halfPixelCenter, NDArray *output) {
|
||||||
|
ImageResizerState st(alignCorners, halfPixelCenter);
|
||||||
|
st.validateAndCalculateOutputSize(images, width, height);
|
||||||
|
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -230,28 +258,20 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for TF compatibility
|
|
||||||
if((center && inHeight < 2) || (center && inWidth < 2)){
|
|
||||||
center = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
|
|
||||||
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
|
|
||||||
|
|
||||||
std::vector<BilinearInterpolationData> ys(outHeight + 1);
|
std::vector<BilinearInterpolationData> ys(outHeight + 1);
|
||||||
std::vector<BilinearInterpolationData> xs(outWidth + 1);
|
std::vector<BilinearInterpolationData> xs(outWidth + 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
// Compute the cached interpolation weights on the x and y dimensions.
|
computeInterpolationWeights(HalfPixelScaler(), outHeight, inHeight, st.heightScale,
|
||||||
computeInterpolationWeights(outHeight, inHeight, heightScale,
|
|
||||||
ys.data());
|
ys.data());
|
||||||
computeInterpolationWeights(outWidth, inWidth, widthScale, xs.data());
|
computeInterpolationWeights(HalfPixelScaler(), outWidth, inWidth, st.widthScale, xs.data());
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Compute the cached interpolation weights on the x and y dimensions.
|
||||||
|
computeInterpolationWeights(LegacyScaler(), outHeight, inHeight, st.heightScale,
|
||||||
|
ys.data());
|
||||||
|
computeInterpolationWeights(LegacyScaler(), outWidth, inWidth, st.widthScale, xs.data());
|
||||||
|
}
|
||||||
int xsSize = xs.size();
|
int xsSize = xs.size();
|
||||||
// Scale x interpolation weights to avoid a multiplication during iteration.
|
// Scale x interpolation weights to avoid a multiplication during iteration.
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
@ -262,71 +282,84 @@ namespace helpers {
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, xsSize);
|
samediff::Threads::parallel_for(func, 0, xsSize);
|
||||||
|
|
||||||
resizeImage(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output);
|
resizeImage_<X,Z>(images->getDataBuffer()->primaryAsT<X>(), batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output->dataBuffer()->primaryAsT<Z>());
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template <class Scaler, typename T>
|
||||||
int resizeNeighborFunctor_(NDArray const *images, int width, int height, bool center, NDArray *output) {
|
void resizeNeighbor(ImageResizerState const& st, NDArray const *images, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = st.batchSize;
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = st.inHeight;
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = st.inWidth;
|
||||||
const Nd4jLong channels = images->sizeAt(3);
|
const Nd4jLong channels = st.channels;
|
||||||
|
|
||||||
const Nd4jLong outHeight = output->sizeAt(1);
|
const Nd4jLong outHeight = st.outHeight;
|
||||||
const Nd4jLong outWidth = output->sizeAt(2);
|
const Nd4jLong outWidth = st.outWidth;
|
||||||
|
Scaler scaler;
|
||||||
// Handle no-op resizes efficiently.
|
|
||||||
if (outHeight == inHeight && outWidth == inWidth) {
|
|
||||||
output->assign(images);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
|
|
||||||
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR_2D {
|
auto func = PRAGMA_THREADS_FOR_2D {
|
||||||
for (auto b = start_x; b < stop_x; b += inc_x) {
|
for (auto b = start_x; b < stop_x; b += inc_x) {
|
||||||
for (auto y = start_y; y < stop_y; y += inc_y) {
|
for (auto y = start_y; y < stop_y; y += inc_y) {
|
||||||
Nd4jLong inY = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(y * heightScale)), inHeight - 1);
|
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(y, st.heightScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(y, st.heightScale)));
|
||||||
|
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
|
inY = nd4j::math::nd4j_max(0LL, inY);
|
||||||
|
}
|
||||||
for (auto x = 0; x < outWidth; ++x) {
|
for (auto x = 0; x < outWidth; ++x) {
|
||||||
Nd4jLong inX = nd4j::math::nd4j_min((center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(x * widthScale)),inWidth - 1);
|
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(scaler(x, st.widthScale))) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(scaler(x, st.widthScale)));
|
||||||
|
Nd4jLong inX = nd4j::math::nd4j_min(posX,inWidth - 1);
|
||||||
|
if (halfPixelCenter) {
|
||||||
|
inX = nd4j::math::nd4j_max(0LL, inX);
|
||||||
|
}
|
||||||
|
// copy pixel over all channels
|
||||||
for (auto e = 0; e < channels; e++)
|
for (auto e = 0; e < channels; e++)
|
||||||
output->p(b, y, x, e, images->e<T>(b, inY, inX, e));
|
output->t<T>(b, y, x, e) = images->t<T>(b, inY, inX, e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
samediff::Threads::parallel_for(func, 0, batchSize, 1, 0, outHeight, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
int resizeNeighborFunctor_(NDArray const *images, int const width, int const height, bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
|
ImageResizerState st(alignCorners, halfPixelCenter);
|
||||||
|
st.validateAndCalculateOutputSize(images, width, height);
|
||||||
|
|
||||||
|
// Handle no-op resizes efficiently.
|
||||||
|
if (output->sizeAt(1) == images->sizeAt(1) && output->sizeAt(2) == images->sizeAt(2)) {
|
||||||
|
output->assign(images);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (halfPixelCenter)
|
||||||
|
resizeNeighbor<HalfPixelScalerNN, T>(st, images, alignCorners, true, output);
|
||||||
|
else
|
||||||
|
resizeNeighbor<LegacyScaler, T>(st, images, alignCorners, false, output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
// void resizeImage(NDArray const *images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
// Nd4jLong outWidth, Nd4jLong channels,
|
||||||
std::vector<BilinearInterpolationData> const &xs,
|
// std::vector<BilinearInterpolationData> const &xs,
|
||||||
std::vector<BilinearInterpolationData> const &ys,
|
// std::vector<BilinearInterpolationData> const &ys,
|
||||||
NDArray *output) {
|
// NDArray *output) {
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_,
|
// BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), resizeImage_,
|
||||||
(images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
|
// (images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs, ys, output),
|
||||||
LIBND4J_TYPES);
|
// NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
|
// }
|
||||||
|
|
||||||
|
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||||
|
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_,
|
||||||
|
(images, width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
|
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int const width, int const height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_,
|
bool const alignCorners, bool const halfPixelCenter, NDArray *output) {
|
||||||
(images, width, height, center, output), LIBND4J_TYPES);
|
|
||||||
}
|
|
||||||
|
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const *images, int width, int height, bool center, NDArray *output) {
|
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
||||||
(images, width, height, center, output), LIBND4J_TYPES);
|
(images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -586,16 +619,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Older incorrect scaling method that causes all resizes to have a slight
|
|
||||||
// translation leading to inconsistent results. For example, a flip then a
|
|
||||||
// resize gives different results then a resize then a flip.
|
|
||||||
struct LegacyScaler {
|
|
||||||
LegacyScaler(){};
|
|
||||||
inline float operator()(const int x, const float scale) const {
|
|
||||||
return static_cast<float>(x) * scale;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
|
static void computeXWeightsAndIndices(const ImageResizerState& resizer_state,
|
||||||
const bool half_pixel_centers,
|
const bool half_pixel_centers,
|
||||||
std::vector<WeightsAndIndices>* x_wais) {
|
std::vector<WeightsAndIndices>* x_wais) {
|
||||||
|
@ -847,7 +870,7 @@ namespace helpers {
|
||||||
// simplified bicubic resize without antialiasing
|
// simplified bicubic resize without antialiasing
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA_(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
||||||
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
|
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
|
||||||
int res = st.validateAndCreateOutput(image, width, height);
|
int res = st.validateAndCreateOutput(image, width, height);
|
||||||
|
@ -856,17 +879,17 @@ namespace helpers {
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
|
BUILD_SINGLE_SELECTOR(image->dataType(), return resizeBicubicFunctorA_, (context,
|
||||||
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
image, width, height, alignCorners, halfPixelAlign, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
// ------------------------------------------------------------------------------------------------------------------ //
|
// ------------------------------------------------------------------------------------------------------------------ //
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, output); break;
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
case kResizeLanczos5:
|
case kResizeLanczos5:
|
||||||
case kResizeGaussian:
|
case kResizeGaussian:
|
||||||
|
|
|
@ -13,6 +13,20 @@
|
||||||
*
|
*
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author sgazeos@gmail.com
|
// @author sgazeos@gmail.com
|
||||||
|
@ -32,6 +46,38 @@ namespace helpers {
|
||||||
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
// https://en.wikipedia.org/wiki/Bilinear_interpolation)
|
||||||
double interpolarValue;
|
double interpolarValue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Older incorrect scaling method that causes all resizes to have a slight
|
||||||
|
// translation leading to inconsistent results. For example, a flip then a
|
||||||
|
// resize gives different results then a resize then a flip.
|
||||||
|
struct LegacyScaler {
|
||||||
|
_CUDA_HD LegacyScaler(){};
|
||||||
|
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
||||||
|
return static_cast<float>(x) * scale;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
||||||
|
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
||||||
|
struct HalfPixelScaler {
|
||||||
|
_CUDA_HD HalfPixelScaler(){};
|
||||||
|
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
||||||
|
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
||||||
|
// sampling code etc assumes pixels are in the old coordinate system.
|
||||||
|
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Utility functions
|
||||||
|
// calculateResizeScale determines the float scaling factor.
|
||||||
|
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
||||||
|
bool alignCorners) {
|
||||||
|
return (alignCorners && outSize > 1)
|
||||||
|
? (inSize - 1) / static_cast<float>(outSize - 1)
|
||||||
|
: inSize / static_cast<float>(outSize);
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// computeInterpolationWeights kernel
|
// computeInterpolationWeights kernel
|
||||||
// outSize - output length
|
// outSize - output length
|
||||||
|
@ -39,6 +85,7 @@ namespace helpers {
|
||||||
// scale - input scale
|
// scale - input scale
|
||||||
// interporationData - result
|
// interporationData - result
|
||||||
//
|
//
|
||||||
|
template <class Scaler>
|
||||||
static __global__ void computeInterpolationWeights(Nd4jLong outSize,
|
static __global__ void computeInterpolationWeights(Nd4jLong outSize,
|
||||||
Nd4jLong inSize,
|
Nd4jLong inSize,
|
||||||
double scale,
|
double scale,
|
||||||
|
@ -48,12 +95,18 @@ namespace helpers {
|
||||||
interpolationData[outSize].topIndex = 0;
|
interpolationData[outSize].topIndex = 0;
|
||||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
Scaler scaler;
|
||||||
for (Nd4jLong i = outSize - tid; i >= 0; i -= step) {
|
for (Nd4jLong i = outSize - tid; i >= 0; i -= step) {
|
||||||
double in = i * scale;
|
double in = scaler(i, scale);
|
||||||
interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
|
// interpolationData[i].bottomIndex = static_cast<Nd4jLong>(in);
|
||||||
interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
|
// interpolationData[i].topIndex = nd4j::math::nd4j_min(interpolationData[i].bottomIndex + 1, inSize - 1);
|
||||||
interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
|
// interpolationData[i].interpolarValue = in - interpolationData[i].bottomIndex;
|
||||||
|
double const in_f = nd4j::math::p_floor<double>(in);
|
||||||
|
double const in_c = nd4j::math::p_ceil<double>(in);
|
||||||
|
interpolationData[i].bottomIndex = nd4j::math::nd4j_max(static_cast<Nd4jLong>(in_f), (Nd4jLong)0LL);//static_cast<Nd4jLong>(in);
|
||||||
|
interpolationData[i].topIndex = nd4j::math::nd4j_min(static_cast<Nd4jLong>(in_c), inSize - 1);
|
||||||
|
interpolationData[i].interpolarValue = in - in_f;
|
||||||
|
|
||||||
if (channels) {
|
if (channels) {
|
||||||
math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels);
|
math::atomics::nd4j_atomicMul(&interpolationData[i].bottomIndex, channels);
|
||||||
math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels);
|
math::atomics::nd4j_atomicMul(&interpolationData[i].topIndex, channels);
|
||||||
|
@ -72,31 +125,33 @@ namespace helpers {
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// resize image with bilinear interpolation algorithm kernel
|
// resize image with bilinear interpolation algorithm kernel
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T, typename Z>
|
||||||
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, T* outputYptr, Nd4jLong* outputShape, Nd4jLong batchSize,
|
static __global__ void resizeImageKernel(T const* input, Nd4jLong const* inputShape, Z* outputYptr,
|
||||||
Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
Nd4jLong* outputShape, Nd4jLong batchSize, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels,
|
||||||
|
Nd4jLong inRowSize, Nd4jLong outRowSize, Nd4jLong inBatchNumValues,
|
||||||
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
BilinearInterpolationData* xs_, BilinearInterpolationData* ys_) {
|
||||||
|
|
||||||
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
|
for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x ) { // blockIdx.x as batch index
|
||||||
auto pX = input + batch * inBatchNumValues;
|
auto pX = input + batch * inBatchNumValues;
|
||||||
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
for (Nd4jLong y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
||||||
const T *ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
const T* ys_input_lower_ptr = pX + ys_[y].bottomIndex * inRowSize;
|
||||||
const T *ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
const T* ys_input_upper_ptr = pX + ys_[y].topIndex * inRowSize;
|
||||||
double yVal = ys_[y].interpolarValue;
|
double yVal = ys_[y].interpolarValue;
|
||||||
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
|
auto pZ = outputYptr + (batch * outHeight + y) * outRowSize;
|
||||||
for (Nd4jLong x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
for (Nd4jLong x = 0; x < outWidth; x++) {
|
||||||
auto xsBottom = xs_[x].bottomIndex;
|
auto xsBottom = xs_[x].bottomIndex;
|
||||||
auto xsTop = xs_[x].topIndex;
|
auto xsTop = xs_[x].topIndex;
|
||||||
auto xVal = xs_[x].interpolarValue;
|
auto xVal = xs_[x].interpolarValue;
|
||||||
// process interpolation for all channels
|
// process interpolation for all channels
|
||||||
for (int c = threadIdx.z; c < channels; c += blockDim.z) {
|
for (int c = 0; c < channels; c++) {
|
||||||
double topLeft(ys_input_lower_ptr[xsBottom + c]);
|
Z topLeft(ys_input_lower_ptr[xsBottom + c]);
|
||||||
double topRight(ys_input_lower_ptr[xsTop + c]);
|
Z topRight(ys_input_lower_ptr[xsTop + c]);
|
||||||
double bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
Z bottomLeft(ys_input_upper_ptr[xsBottom + c]);
|
||||||
double bottomRight(ys_input_upper_ptr[xsTop + c]);
|
Z bottomRight(ys_input_upper_ptr[xsTop + c]);
|
||||||
double top = topLeft + (topRight - topLeft) * xVal;
|
Z top = topLeft + (topRight - topLeft) * xVal;
|
||||||
double bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
|
Z bottom = bottomLeft + (bottomRight - bottomLeft) * xVal;
|
||||||
pZ[x * channels + c] = T(top + (bottom - top) * yVal);
|
Z resVal = Z(top + (bottom - top) * yVal);
|
||||||
|
pZ[x * channels + c] = resVal;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,7 +160,7 @@ namespace helpers {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// resize image with
|
// resize image with
|
||||||
template <typename T>
|
template <typename T, typename F>
|
||||||
static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
static void resizeImage_(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight,
|
||||||
Nd4jLong outWidth, Nd4jLong channels,
|
Nd4jLong outWidth, Nd4jLong channels,
|
||||||
BilinearInterpolationData* xs_,
|
BilinearInterpolationData* xs_,
|
||||||
|
@ -115,12 +170,13 @@ namespace helpers {
|
||||||
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
Nd4jLong inBatchNumValues = inHeight * inRowSize;
|
||||||
Nd4jLong outRowSize = outWidth * channels;
|
Nd4jLong outRowSize = outWidth * channels;
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
T const *input_b_ptr = reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
T const* pInput = images->getDataBuffer()->specialAsT<T>(); //reinterpret_cast<T const *>(images->getSpecialBuffer()); // this works only with 'c' direction
|
||||||
T *output_y_ptr = reinterpret_cast<T *>(output->specialBuffer());
|
F* pOutput = output->dataBuffer()->specialAsT<F>();//reinterpret_cast<F *>(output->specialBuffer());
|
||||||
dim3 batchSizeBlock(batchSize, 1, 1);
|
dim3 batchSizeBlock(batchSize, 1, 1);
|
||||||
dim3 pictureBlock(outHeight, outWidth, channels);
|
dim3 pictureBlock(outHeight, outWidth, channels);
|
||||||
resizeImageKernel<T><<<256, pictureBlock, 256, *stream>>>(input_b_ptr, images->getSpecialShapeInfo(), output_y_ptr, output->specialShapeInfo(), batchSize,
|
resizeImageKernel<T,F><<<256, 256, 256, *stream>>>(pInput, images->getSpecialShapeInfo(), pOutput,
|
||||||
outWidth, outHeight, channels, inRowSize, outRowSize, inBatchNumValues, xs_, ys_);
|
output->specialShapeInfo(), batchSize, outWidth, outHeight, channels, inRowSize, outRowSize,
|
||||||
|
inBatchNumValues, xs_, ys_);
|
||||||
|
|
||||||
auto err = cudaStreamSynchronize(*stream);
|
auto err = cudaStreamSynchronize(*stream);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
|
@ -129,8 +185,9 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T, typename F>
|
||||||
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
static int resizeBilinearFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width,
|
||||||
|
int const height, bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -145,19 +202,8 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special case for TF compatibility
|
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
|
||||||
if((center && inHeight < 2) || (center && inWidth < 2)){
|
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
|
||||||
center = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
|
||||||
// wrong input data
|
|
||||||
nd4j_printf("image.resize_bilinear: Wrong input or output size to resize\n", "");
|
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
|
||||||
}
|
|
||||||
float heightScale = center ? (inHeight - 1.f) / double(outHeight - 1.f) : (inHeight / float(outHeight));
|
|
||||||
float widthScale = center ? (inWidth - 1.f) / double(outWidth - 1.f) : (inWidth / float(outWidth));
|
|
||||||
|
|
||||||
BilinearInterpolationData* xs_;// = xs.data();
|
BilinearInterpolationData* xs_;// = xs.data();
|
||||||
BilinearInterpolationData* ys_;// = xs.data();
|
BilinearInterpolationData* ys_;// = xs.data();
|
||||||
|
@ -173,12 +219,24 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
// Compute the cached interpolation weights on the x and y dimensions.
|
// Compute the cached interpolation weights on the x and y dimensions.
|
||||||
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
if (halfPixelCenter) {
|
||||||
computeInterpolationWeights<<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
computeInterpolationWeights <
|
||||||
|
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
||||||
|
computeInterpolationWeights <
|
||||||
|
HalfPixelScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
computeInterpolationWeights <
|
||||||
|
LegacyScaler ><<<256, 512, 512, *stream>>>(outHeight, inHeight, heightScale, 0, ys_);
|
||||||
|
computeInterpolationWeights <
|
||||||
|
LegacyScaler ><<<256, 512, 512, *stream>>>(outWidth, inWidth, widthScale, channels, xs_);
|
||||||
|
}
|
||||||
|
printf("Input is %dx%d, Output is %dx%d\n", inHeight, inWidth, outHeight, outWidth);
|
||||||
NDArray::prepareSpecialUse({output}, {images});
|
NDArray::prepareSpecialUse({output}, {images});
|
||||||
resizeImage(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
resizeImage_<T,F>(context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output);
|
||||||
|
err = cudaStreamSynchronize(*stream);
|
||||||
NDArray::registerSpecialUse({output}, {images});
|
NDArray::registerSpecialUse({output}, {images});
|
||||||
|
|
||||||
err = cudaFree(xs_);
|
err = cudaFree(xs_);
|
||||||
if (err != 0) {
|
if (err != 0) {
|
||||||
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
throw cuda_exception::build("helpers::resize_image: Cannot deallocate memory for vertical parts rectangulars", err);
|
||||||
|
@ -197,20 +255,28 @@ namespace helpers {
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape,
|
static __global__ void resizeNeighborKernel(T const* input, Nd4jLong* inputShape, T* output, Nd4jLong* outputShape,
|
||||||
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center) {
|
Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool alignCorners, bool halfPixelCenters) {
|
||||||
|
|
||||||
//for (int b = blockIdx.x; b < batchSize; b += gridDim.x)
|
//for (int b = blockIdx.x; b < batchSize; b += gridDim.x)
|
||||||
if (blockIdx.x < batchSize)
|
if (blockIdx.x < batchSize)
|
||||||
{
|
{
|
||||||
auto b = blockIdx.x;
|
auto b = blockIdx.x;
|
||||||
for (int y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
for (int y = threadIdx.x; y < outHeight; y += blockDim.x) {
|
||||||
Nd4jLong inY = nd4j::math::nd4j_min(
|
auto posY = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
||||||
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(y * heightScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
halfPixelCenters?((float)y + 0.5f) * heightScale:(float)y * heightScale));
|
||||||
y * heightScale)), inHeight - 1);
|
Nd4jLong inY = nd4j::math::nd4j_min(posY, inHeight - 1);
|
||||||
|
if (halfPixelCenters) {
|
||||||
|
inY = nd4j::math::nd4j_max(0LL, inY);
|
||||||
|
}
|
||||||
|
|
||||||
for (int x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
for (int x = threadIdx.y; x < outWidth; x += blockDim.y) {
|
||||||
Nd4jLong inX = nd4j::math::nd4j_min(
|
auto posX = alignCorners ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
||||||
(center) ? static_cast<Nd4jLong>(nd4j::math::p_round<float>(x * widthScale)) : static_cast<Nd4jLong>(nd4j::math::p_floor<float>(
|
halfPixelCenters?((float)x + 0.5f) * widthScale:(float)x * widthScale));
|
||||||
x * widthScale)), inWidth - 1);
|
Nd4jLong inX = nd4j::math::nd4j_min(posX, inWidth - 1);
|
||||||
|
if (halfPixelCenters) {
|
||||||
|
inX = nd4j::math::nd4j_max(0LL, inX);
|
||||||
|
}
|
||||||
|
|
||||||
auto start = blockIdx.z * blockDim.z + threadIdx.z;
|
auto start = blockIdx.z * blockDim.z + threadIdx.z;
|
||||||
auto step = blockDim.z * gridDim.z;
|
auto step = blockDim.z * gridDim.z;
|
||||||
|
|
||||||
|
@ -231,7 +297,8 @@ namespace helpers {
|
||||||
// resizeNeighborFunctor - main algorithm by nearest neighbor
|
// resizeNeighborFunctor - main algorithm by nearest neighbor
|
||||||
//
|
//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeNeighborFunctor_(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
|
||||||
|
bool const alignCorners, bool const halfPixelCenters, NDArray* output) {
|
||||||
const Nd4jLong batchSize = images->sizeAt(0);
|
const Nd4jLong batchSize = images->sizeAt(0);
|
||||||
const Nd4jLong inHeight = images->sizeAt(1);
|
const Nd4jLong inHeight = images->sizeAt(1);
|
||||||
const Nd4jLong inWidth = images->sizeAt(2);
|
const Nd4jLong inWidth = images->sizeAt(2);
|
||||||
|
@ -246,25 +313,24 @@ namespace helpers {
|
||||||
return ND4J_STATUS_OK;
|
return ND4J_STATUS_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((center && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (center && outHeight < 2) ||
|
// if ((alignCorners && inHeight < 2) || (inHeight < 1) || (outHeight < 1) || (alignCorners && outHeight < 2) ||
|
||||||
(center && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
// (alignCorners && inWidth < 2) || (inWidth < 1) || (outWidth < 1) || (center && outWidth < 2)) {
|
||||||
// wrong input data
|
// // wrong input data
|
||||||
nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
// nd4j_printf("image.resize_nearest_neighbor: Wrong input or output size to resize\n", "");
|
||||||
return ND4J_STATUS_BAD_ARGUMENTS;
|
// return ND4J_STATUS_BAD_ARGUMENTS;
|
||||||
}
|
// }
|
||||||
double heightScale = center ? (inHeight - 1.) / double(outHeight - 1.0) : (inHeight / double(outHeight));
|
// float heightScale = alignCorners ? (inHeight - 1.f) / float(outHeight - 1.f) : (inHeight / float(outHeight));
|
||||||
double widthScale = center ? (inWidth - 1.) / double(outWidth - 1.0) : (inWidth / double(outWidth));
|
// float widthScale = alignCorners ? (inWidth - 1.f) / float(outWidth - 1.f) : (inWidth / float(outWidth));
|
||||||
auto imagesBuffer = reinterpret_cast<T const*>(images->getSpecialBuffer());
|
float heightScale = calculateResizeScale(inHeight, outHeight, alignCorners);
|
||||||
auto outputBuffer = reinterpret_cast<T*>(output->specialBuffer());
|
float widthScale = calculateResizeScale(inWidth, outWidth, alignCorners);
|
||||||
|
|
||||||
|
auto imagesBuffer = images->getDataBuffer()->specialAsT<T>();//reinterpret_cast<T const*>(images->getSpecialBuffer());
|
||||||
|
auto outputBuffer = output->dataBuffer()->specialAsT<T>();//reinterpret_cast<T*>(output->specialBuffer());
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
|
|
||||||
//T const* input, Nd4jLong const* inputShape, T* output, Nd4jLong* outputShape,
|
|
||||||
// Nd4jLong batchSize, Nd4jLong inWidth, Nd4jLong inHeight, Nd4jLong outWidth, Nd4jLong outHeight, Nd4jLong channels, double widthScale, double heightScale, bool center
|
|
||||||
//input, inputShape, output, outputShape,
|
|
||||||
// batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center
|
|
||||||
NDArray::prepareSpecialUse({output}, {images});
|
NDArray::prepareSpecialUse({output}, {images});
|
||||||
resizeNeighborKernel<T><<<batchSize, outHeight * outWidth, 512, *stream>>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(),
|
resizeNeighborKernel<T><<<batchSize, outHeight * outWidth, 512, *stream>>>(imagesBuffer, images->getSpecialShapeInfo(), outputBuffer, output->specialShapeInfo(),
|
||||||
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, center);
|
batchSize, inWidth, inHeight, outWidth, outHeight, channels, widthScale, heightScale, alignCorners, halfPixelCenters);
|
||||||
NDArray::registerSpecialUse({output}, {images});
|
NDArray::registerSpecialUse({output}, {images});
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -275,39 +341,38 @@ namespace helpers {
|
||||||
void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight,
|
void resizeImage(nd4j::LaunchContext* context, NDArray const* images, Nd4jLong batchSize, Nd4jLong inHeight,
|
||||||
Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_,
|
Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth, Nd4jLong channels, BilinearInterpolationData* xs_,
|
||||||
BilinearInterpolationData* ys_, NDArray* output) {
|
BilinearInterpolationData* ys_, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels, xs_, ys_, output), LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(),
|
||||||
|
resizeImage_, (context, images, batchSize, inHeight, inWidth, outHeight, outWidth, channels,
|
||||||
|
xs_, ys_, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
|
BUILD_DOUBLE_TEMPLATE(template void resizeImage_,(nd4j::LaunchContext* context, NDArray const* images,
|
||||||
Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth,
|
Nd4jLong batchSize, Nd4jLong inHeight, Nd4jLong inWidth, Nd4jLong outHeight, Nd4jLong outWidth,
|
||||||
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output), LIBND4J_TYPES);
|
Nd4jLong channels, BilinearInterpolationData* xs_, BilinearInterpolationData* ys_, NDArray* output),
|
||||||
|
NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeBilinearFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeBilinearFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(images->dataType(), output->dataType(), return resizeBilinearFunctor_, (context, images,
|
||||||
|
width, height, alignCorners, halfPixelCenter, output), NUMERIC_TYPES, FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
// BUILD_SINGLE_TEMPLATE(template int resizeBilinearFunctor_, (nd4j::LaunchContext* context,
|
||||||
|
// NDArray const* images, int const width, int const height, bool const alignCorners,
|
||||||
|
// bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int width, int height, bool center, NDArray* output) {
|
int resizeNeighborFunctor(nd4j::LaunchContext* context, NDArray const* images, int const width, int const height,
|
||||||
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_, (context, images, width, height, center, output), LIBND4J_TYPES);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(images->dataType(), return resizeNeighborFunctor_,
|
||||||
|
(context, images, width, height, alignCorners, halfPixelCenter, output), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
// BUILD_SINGLE_TEMPLATE(template int resizeNeighborFunctor_, (nd4j::LaunchContext* context, NDArray const* images,
|
||||||
int width, int height, bool center, NDArray* output), LIBND4J_TYPES);
|
// int width, int height, bool const alignCorners, bool const halfPixelCenter, NDArray* output), LIBND4J_TYPES);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Bicubic interpolation
|
// Bicubic interpolation
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Utility functions and classes
|
|
||||||
|
|
||||||
// calculateResizeScale determines the float scaling factor.
|
|
||||||
inline float calculateResizeScale(Nd4jLong inSize, Nd4jLong outSize,
|
|
||||||
bool alignCorners) {
|
|
||||||
return (alignCorners && outSize > 1)
|
|
||||||
? (inSize - 1) / static_cast<float>(outSize - 1)
|
|
||||||
: inSize / static_cast<float>(outSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ImageResizerState {
|
struct ImageResizerState {
|
||||||
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
|
explicit ImageResizerState(bool alignCorners, bool halfPixelCenters)
|
||||||
: _alignCorners(alignCorners),
|
: _alignCorners(alignCorners),
|
||||||
|
@ -362,17 +427,6 @@ namespace helpers {
|
||||||
bool _halfPixelCenters;
|
bool _halfPixelCenters;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Half pixel scaler scales assuming that the pixel centers are at 0.5, i.e. the
|
|
||||||
// floating point coordinates of the top,left pixel is 0.5,0.5.
|
|
||||||
struct HalfPixelScaler {
|
|
||||||
_CUDA_HD HalfPixelScaler(){};
|
|
||||||
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
|
||||||
// Note that we subtract 0.5 from the return value, as the existing bilinear
|
|
||||||
// sampling code etc assumes pixels are in the old coordinate system.
|
|
||||||
return (static_cast<float>(x) + 0.5f) * scale - 0.5f;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WeightsAndIndices {
|
struct WeightsAndIndices {
|
||||||
float _weight0;
|
float _weight0;
|
||||||
float _weight1;
|
float _weight1;
|
||||||
|
@ -547,16 +601,6 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Older incorrect scaling method that causes all resizes to have a slight
|
|
||||||
// translation leading to inconsistent results. For example, a flip then a
|
|
||||||
// resize gives different results then a resize then a flip.
|
|
||||||
struct LegacyScaler {
|
|
||||||
_CUDA_HD LegacyScaler(){};
|
|
||||||
inline _CUDA_HD float operator()(const int x, const float scale) const {
|
|
||||||
return static_cast<float>(x) * scale;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
|
static __global__ void accumulateChannelsKernel(WeightsAndIndices* pXWais, Nd4jLong outWidth, Nd4jLong channels) {
|
||||||
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
auto step = blockDim.x * gridDim.x;
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
@ -906,8 +950,8 @@ namespace helpers {
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output) {
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, output); break;
|
case kResizeBilinear: return resizeBilinearFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, true, output); break;
|
case kResizeNearest: return resizeNeighborFunctor(context, image, width, height, false, false, output); break;
|
||||||
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
case kResizeBicubic: return resizeBicubicFunctor(context, image, width, height, preserveAspectRatio, antialias, output); break;
|
||||||
case kResizeLanczos5:
|
case kResizeLanczos5:
|
||||||
case kResizeGaussian:
|
case kResizeGaussian:
|
||||||
|
|
|
@ -30,6 +30,67 @@ namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void reverseTadKernel(void* vinput, Nd4jLong *inputShape, void* voutput, Nd4jLong *outputShape, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t limit, uint64_t numOfElemsToReverse, uint64_t numTads) {
|
||||||
|
auto input = reinterpret_cast<T*>(vinput);
|
||||||
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
|
// this means that we'll have additional cycle, to move middle element
|
||||||
|
auto div = numOfElemsToReverse / 2;
|
||||||
|
auto odd = numOfElemsToReverse % 2 != 0;
|
||||||
|
auto rlimit = odd ? limit / 2 + 1 : limit / 2;
|
||||||
|
|
||||||
|
// all threads operate in the same input/output space
|
||||||
|
for (uint64_t e = tid; e < rlimit; e += step) {
|
||||||
|
// finding out the TAD we're going to process
|
||||||
|
auto tadId = e / div;
|
||||||
|
|
||||||
|
if (tadId >= numTads)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// now finding out element within tad
|
||||||
|
auto idx = e % div;
|
||||||
|
|
||||||
|
//printf("TID: %i; numTads: %lld; tadLength: %lld; tadId: %i, idx: %lld\n", tid, numTads, numOfElemsToReverse, tadId, idx);
|
||||||
|
|
||||||
|
auto tadInput = input + inputTadOffsets[tadId];
|
||||||
|
auto tadOutput = output + outputTadOffsets[tadId];
|
||||||
|
|
||||||
|
// we're calculating offsets within input TAD
|
||||||
|
auto fOffset = shape::getIndexOffset(idx, inputTadShape);
|
||||||
|
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, inputTadShape);
|
||||||
|
|
||||||
|
// now we're storing input values
|
||||||
|
auto v1 = tadInput[fOffset];
|
||||||
|
auto v2 = tadInput[lOffset];
|
||||||
|
|
||||||
|
// now we're calculating offsets within output TAD
|
||||||
|
auto zfOffset = shape::getIndexOffset(idx, outputTadShape);
|
||||||
|
auto zlOffset = shape::getIndexOffset(numOfElemsToReverse - idx - 1, outputTadShape);
|
||||||
|
|
||||||
|
// and saving values to output arrays
|
||||||
|
tadOutput[zfOffset] = v2;
|
||||||
|
tadOutput[zlOffset] = v1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// moving odd element in blocks
|
||||||
|
if (odd && threadIdx.x == 0) {
|
||||||
|
for (uint64_t e = blockIdx.x; e < numTads; e += gridDim.x) {
|
||||||
|
auto tadInput = input + inputTadOffsets[e];
|
||||||
|
auto tadOutput = output + outputTadOffsets[e];
|
||||||
|
|
||||||
|
auto xOffset = shape::getIndexOffset(numOfElemsToReverse / 2, inputTadShape);
|
||||||
|
auto zOffset = shape::getIndexOffset(numOfElemsToReverse / 2, outputTadShape);
|
||||||
|
|
||||||
|
tadOutput[zOffset] = tadInput[xOffset];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
|
static __global__ void reverseArrayKernel(void* input, Nd4jLong *inputShape, void* output, Nd4jLong *outputShape, Nd4jLong numOfElemsToReverse) {
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
@ -52,7 +113,7 @@ namespace helpers {
|
||||||
auto odd = numOfElemsToReverse % 2 != 0;
|
auto odd = numOfElemsToReverse % 2 != 0;
|
||||||
auto limit = numOfElemsToReverse / 2;
|
auto limit = numOfElemsToReverse / 2;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < limit; e += step) {
|
for (uint64_t e = tid; e < limit; e += step) {
|
||||||
// we're calculating offsets within input array
|
// we're calculating offsets within input array
|
||||||
auto fOffset = shape::getIndexOffset(e, inputShape);
|
auto fOffset = shape::getIndexOffset(e, inputShape);
|
||||||
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
|
auto lOffset = shape::getIndexOffset(numOfElemsToReverse - e - 1, inputShape);
|
||||||
|
@ -80,13 +141,19 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void reverseArray(nd4j::LaunchContext * context, NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
static void reverseTad(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong *inputTadShape, Nd4jLong *inputTadOffsets, Nd4jLong *outputTadShape, Nd4jLong *outputTadOffsets, uint64_t tadLength) {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
reverseTadKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), inputTadShape, inputTadOffsets, outputTadShape, outputTadOffsets, input->lengthOf(), tadLength, input->lengthOf() / tadLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static void reverseArray(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, Nd4jLong numOfElemsToReverse) {
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
Nd4jLong numOfReverse = numOfElemsToReverse;
|
Nd4jLong numOfReverse = numOfElemsToReverse;
|
||||||
if (numOfElemsToReverse == 0)
|
if (numOfElemsToReverse == 0)
|
||||||
numOfReverse = input->lengthOf();
|
numOfReverse = input->lengthOf();
|
||||||
|
|
||||||
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
reverseArrayKernel<T><<<256, 512, 8192, *stream>>>(input->getSpecialBuffer(), input->getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), numOfReverse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -153,27 +220,23 @@ namespace helpers {
|
||||||
// we need to reverse axis only if that's new op
|
// we need to reverse axis only if that's new op
|
||||||
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
||||||
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), axis);
|
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), dimensions);
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), axis);
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), dimensions);
|
||||||
|
|
||||||
auto listOut = output->allTensorsAlongDimension(dimensions);
|
|
||||||
auto listIn = input->allTensorsAlongDimension(dimensions);
|
|
||||||
|
|
||||||
NDArray *subArrIn, *subArrOut;
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({output}, {input});
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
|
|
||||||
subArrIn = listIn->at(i);
|
if (packX.numberOfTads() == 1) {
|
||||||
subArrOut = listOut->at(i);
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, input, output, 0), LIBND4J_TYPES);
|
||||||
BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, subArrIn, subArrOut, 0), LIBND4J_TYPES);
|
} else {
|
||||||
|
BUILD_SINGLE_SELECTOR(input->dataType(), reverseTad, (context, input, output, packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), (uint64_t) (input->lengthOf() / packX.numberOfTads())), LIBND4J_TYPES);
|
||||||
}
|
}
|
||||||
//BUILD_SINGLE_SELECTOR(input->dataType(), reverseArray, (context, const_cast<NDArray*>(input), output, (int)0), LIBND4J_TYPES);
|
|
||||||
NDArray::registerSpecialUse({output}, {input});
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
delete listOut;
|
|
||||||
delete listIn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void reverseArray, (nd4j::LaunchContext * context, const NDArray *inArr, NDArray *outArr, Nd4jLong numOfElemsToReverse), LIBND4J_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,15 +37,15 @@ namespace helpers {
|
||||||
kResizeArea
|
kResizeArea
|
||||||
};
|
};
|
||||||
|
|
||||||
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
int resizeBilinearFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
NDArray* output);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
|
||||||
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height, bool center,
|
int resizeNeighborFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
NDArray* output);
|
bool const alignCorners, bool const halfPixelCenter, NDArray* output);
|
||||||
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool preserveAspectRatio, bool antialias, NDArray* output);
|
bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeBicubicFunctorA(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
|
bool const alignCorners, bool const halfPixelAlign, NDArray* output);
|
||||||
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int width, int height,
|
int resizeFunctor(nd4j::LaunchContext * context, NDArray const* image, int const width, int const height,
|
||||||
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
|
ImageResizeMethods method, bool preserveAspectRatio, bool antialias, NDArray* output);
|
||||||
|
|
||||||
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,
|
void cropAndResizeFunctor(nd4j::LaunchContext * context, NDArray const* images, NDArray const* boxes,
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -212,12 +212,12 @@ TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
|
TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) {
|
||||||
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139};
|
TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f};
|
||||||
Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
||||||
NDArray expGWP(_expGradWpB, _expGradWpS);
|
NDArray expGWP(_expGradWpB, _expGradWpS);
|
||||||
expGWP.permutei({2,3,1,0});
|
expGWP.permutei({2,3,1,0});
|
||||||
|
|
||||||
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747};
|
TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 4796.89848f, 4806.747f};
|
||||||
Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99};
|
||||||
NDArray expGWD(_expGradWdB, _expGradWdS);
|
NDArray expGWD(_expGradWdB, _expGradWdS);
|
||||||
expGWD.permutei({2,3,1,0});
|
expGWD.permutei({2,3,1,0});
|
||||||
|
|
|
@ -1594,7 +1594,7 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
|
||||||
|
|
||||||
auto z = NDArrayFactory::create_<float>('f', {5, 1});
|
auto z = NDArrayFactory::create_<float>('f', {5, 1});
|
||||||
|
|
||||||
auto expBuffer = new float[5]{28.00,64.00,100.00,136.00,172.00};
|
auto expBuffer = new float[5]{28.00f,64.00f,100.00f,136.00f,172.00f};
|
||||||
auto exp = new NDArray(expBuffer, z->getShapeInfo());
|
auto exp = new NDArray(expBuffer, z->getShapeInfo());
|
||||||
|
|
||||||
nd4j::blas::GEMV<float, float, float>::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1);
|
nd4j::blas::GEMV<float, float, float>::op('f', x->rows(), x->columns(), 1.0f, x->getBuffer(), y->rows(), y->getBuffer(), 1, 0.0, z->getBuffer(), 1);
|
||||||
|
@ -3523,7 +3523,8 @@ TEST_F(DeclarableOpsTests1, Reverse_7 ) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
auto result = results->at(0);
|
auto result = results->at(0);
|
||||||
// result->printBuffer();
|
//expected.printIndexedBuffer("E");
|
||||||
|
//result->printIndexedBuffer("R");
|
||||||
|
|
||||||
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
ASSERT_TRUE(expected.isSameShapeStrict(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
@ -3605,7 +3606,9 @@ TEST_F(DeclarableOpsTests1, Reverse_11 ) {
|
||||||
|
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<float>('c', {2,3,4});
|
auto input = NDArrayFactory::create<float>('c', {2,3,4});
|
||||||
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.});
|
auto expected = NDArrayFactory::create<float>('c', {2,3,4}, {24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f,
|
||||||
|
15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f,
|
||||||
|
6.f, 5.f, 4.f, 3.f, 2.f, 1.f});
|
||||||
|
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
nd4j::ops::reverse op;
|
nd4j::ops::reverse op;
|
||||||
|
|
|
@ -121,10 +121,10 @@ TEST_F(DeclarableOpsTests10, Test_Or_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, Test_Not_1) {
|
TEST_F(DeclarableOpsTests10, Test_Not_1) {
|
||||||
auto x = NDArrayFactory::create<bool>('c', {4}, {1, 1, 0, 1});
|
auto x = NDArrayFactory::create<bool>('c', {4}, {true, true, false, true});
|
||||||
auto y = NDArrayFactory::create<bool>('c', {4}, {0, 0, 0, 1});
|
auto y = NDArrayFactory::create<bool>('c', {4}, {false, false, false, true});
|
||||||
// auto e = NDArrayFactory::create<bool>('c', {4}, {1, 1, 1, 0});
|
// auto e = NDArrayFactory::create<bool>('c', {4}, {1, 1, 1, 0});
|
||||||
auto e = NDArrayFactory::create<bool>('c', {4}, {0, 0, 1, 0});
|
auto e = NDArrayFactory::create<bool>('c', {4}, {false, false, true, false});
|
||||||
|
|
||||||
nd4j::ops::boolean_not op;
|
nd4j::ops::boolean_not op;
|
||||||
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
|
auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::BOOL);
|
||||||
|
@ -245,7 +245,8 @@ TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) {
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
|
TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) {
|
||||||
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1});
|
auto cond2d = NDArrayFactory::create<bool>('c', {3, 5}, {true, true, false, false, true, true, true,
|
||||||
|
true, true, true, false, true, true, true, true});
|
||||||
// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1});
|
// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1});
|
||||||
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
|
auto exp1 = NDArrayFactory::create<Nd4jLong>({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2});
|
||||||
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
|
auto exp2 = NDArrayFactory::create<Nd4jLong>({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4});
|
||||||
|
@ -623,7 +624,7 @@ TEST_F(DeclarableOpsTests10, range_test11) {
|
||||||
//////////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, range_test12) {
|
TEST_F(DeclarableOpsTests10, range_test12) {
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5});
|
auto exp = NDArrayFactory::create<float>('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f});
|
||||||
|
|
||||||
nd4j::ops::range op;
|
nd4j::ops::range op;
|
||||||
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
|
auto result = op.execute({}, {0.5, 5, 0.5}, {}, {});
|
||||||
|
@ -1416,7 +1417,7 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
|
||||||
////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
|
||||||
|
@ -1470,6 +1471,138 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
|
||||||
|
|
||||||
|
input.assign(0.8f); //linspace(1);
|
||||||
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input, &size}, {}, {}, {false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
ASSERT_NE(*result, ex);
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<float>('c', {1, 1, 1, 256});
|
||||||
|
|
||||||
|
input.assign(0.8f); //linspace(1);
|
||||||
|
auto size = NDArrayFactory::create<int>({65,65});
|
||||||
|
auto ex = NDArrayFactory::create<float>('c', {1,65,65,256});
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
ASSERT_NE(*result, ex);
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
|
||||||
|
1., 2., 3., 4.,
|
||||||
|
2.6, 3.6, 4.6, 5.6,
|
||||||
|
5., 6., 7., 8.,
|
||||||
|
7.4, 8.4, 9.4, 10.4,
|
||||||
|
9., 10., 11., 12.,
|
||||||
|
|
||||||
|
4., 5., 6., 7.,
|
||||||
|
5.6, 6.6, 7.6, 8.6,
|
||||||
|
8., 9., 10., 11.,
|
||||||
|
10.4, 11.4, 12.4, 13.4,
|
||||||
|
12., 13., 14., 15.,
|
||||||
|
|
||||||
|
10., 11., 12., 13.,
|
||||||
|
11.6, 12.6, 13.6, 14.6,
|
||||||
|
14., 15., 16., 17.,
|
||||||
|
16.4, 17.4, 18.4, 19.4,
|
||||||
|
18., 19., 20., 21.,
|
||||||
|
|
||||||
|
13., 14., 15., 16.,
|
||||||
|
14.6, 15.6, 16.6, 17.6,
|
||||||
|
17., 18., 19., 20.,
|
||||||
|
19.4, 20.4, 21.4, 22.4,
|
||||||
|
21., 22., 23., 24.
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printIndexedBuffer("Resized to 4x5 bilinear with half pixels");
|
||||||
|
//expected.printIndexedBuffer("Expect for 10x10");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
|
||||||
|
1.f, 2.f, 3.f, 4.f,
|
||||||
|
2.6f, 3.6f, 4.6f, 5.6f,
|
||||||
|
5.f, 6.f, 7.f, 8.f,
|
||||||
|
7.4f, 8.4f, 9.4f, 10.4f,
|
||||||
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
|
4.f, 5.f, 6.f, 7.f,
|
||||||
|
5.6f, 6.6f, 7.6f, 8.6f,
|
||||||
|
8.f, 9.f, 10.f, 11.f,
|
||||||
|
10.4f, 11.4f, 12.4f, 13.4f,
|
||||||
|
12.f, 13.f, 14.f, 15.f,
|
||||||
|
|
||||||
|
10.f, 11.f, 12.f, 13.f,
|
||||||
|
11.6f, 12.6f, 13.6f, 14.6f,
|
||||||
|
14.f, 15.f, 16.f, 17.f,
|
||||||
|
16.4f, 17.4f, 18.4f, 19.4f,
|
||||||
|
18.f, 19.f, 20.f, 21.f,
|
||||||
|
|
||||||
|
13.f, 14.f, 15.f, 16.f,
|
||||||
|
14.6f, 15.6f, 16.6f, 17.6f,
|
||||||
|
17.f, 18.f, 19.f, 20.f,
|
||||||
|
19.4f, 20.4f, 21.4f, 22.4f,
|
||||||
|
21.f, 22.f, 23.f, 24.f
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_bilinear op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, true});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printBuffer("Resized to 4x5");
|
||||||
|
// expected.printBuffer("Expect for 4x5");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
|
NDArray input = NDArrayFactory::create<double>('c', {2,3,4});
|
||||||
|
@ -1857,7 +1990,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input}, {}, {10, 10, 1});
|
auto results = op.execute({&input}, {}, {10, 10}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -1986,7 +2119,7 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_bilinear op;
|
nd4j::ops::resize_bilinear op;
|
||||||
auto results = op.execute({&input, &size}, {}, {1});
|
auto results = op.execute({&input, &size}, {}, {}, {true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -2023,7 +2156,56 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
|
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, {
|
||||||
|
1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
1, 2, 3, 4,
|
||||||
|
1, 2, 3, 4,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
5, 6, 7, 8,
|
||||||
|
9, 10, 11, 12,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24,
|
||||||
|
|
||||||
|
13, 14, 15, 16,
|
||||||
|
13, 14, 15, 16,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24
|
||||||
|
});
|
||||||
|
//input = 1.f;
|
||||||
|
input.linspace(1);
|
||||||
|
|
||||||
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
|
auto results = op.execute({&input}, {}, {4, 5}, {false, false});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
|
// result->printIndexedBuffer("Resized to 4x5");
|
||||||
|
// expected.printIndexedBuffer("Expect for 4x5");
|
||||||
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
||||||
|
|
||||||
|
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
||||||
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
|
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, {
|
||||||
|
1, 2, 3, 4,
|
||||||
1, 2, 3, 4,
|
1, 2, 3, 4,
|
||||||
5, 6, 7, 8,
|
5, 6, 7, 8,
|
||||||
5, 6, 7, 8,
|
5, 6, 7, 8,
|
||||||
|
@ -2065,47 +2247,48 @@ TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) {
|
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) {
|
||||||
|
|
||||||
NDArray input = NDArrayFactory::create<int>('c', {1, 2, 3, 4});
|
NDArray input = NDArrayFactory::create<float>('c', {1, 2, 3, 4});
|
||||||
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
|
||||||
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
|
||||||
NDArray expected = NDArrayFactory::create<int>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
|
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, {
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
5.f, 6.f, 7.f, 8.f,
|
||||||
9, 10, 11, 12,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
1, 2, 3, 4,
|
1.f, 2.f, 3.f, 4.f,
|
||||||
5, 6, 7, 8,
|
5.f, 6.f, 7.f, 8.f,
|
||||||
5, 6, 7, 8,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
9, 10, 11, 12,
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
17, 18, 19, 20,
|
17.f, 18.f, 19.f, 20.f,
|
||||||
17, 18, 19, 20,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
21, 22, 23, 24,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
|
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
13, 14, 15, 16,
|
13.f, 14.f, 15.f, 16.f,
|
||||||
17, 18, 19, 20,
|
17.f, 18.f, 19.f, 20.f,
|
||||||
17, 18, 19, 20,
|
21.f, 22.f, 23.f, 24.f,
|
||||||
21, 22, 23, 24
|
21.f, 22.f, 23.f, 24.f
|
||||||
});
|
});
|
||||||
//input = 1.f;
|
//input = 1.f;
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
nd4j::ops::resize_nearest_neighbor op;
|
nd4j::ops::resize_nearest_neighbor op;
|
||||||
auto results = op.execute({&input}, {}, {4, 5});
|
auto results = op.execute({&input}, {}, {4,5}, {false, true});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
NDArray* result = results->at(0);
|
NDArray* result = results->at(0);
|
||||||
|
|
||||||
// result->printIndexedBuffer("Resized to 4x5");
|
// result->printIndexedBuffer("Resized to 4x5");
|
||||||
// expected.printIndexedBuffer("Expect for 4x5");
|
// expected.printBuffer("Expect for 4x5");
|
||||||
ASSERT_TRUE(expected.isSameShape(result));
|
ASSERT_TRUE(expected.isSameShape(result));
|
||||||
ASSERT_TRUE(expected.equalsTo(result));
|
ASSERT_TRUE(expected.equalsTo(result));
|
||||||
|
|
||||||
|
@ -2533,7 +2716,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) {
|
||||||
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
|
NDArray cropSize = NDArrayFactory::create<Nd4jLong>({3, 3});
|
||||||
|
|
||||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
NDArray expected('c', {1,3,3,1}, {1, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::crop_and_resize op;
|
nd4j::ops::crop_and_resize op;
|
||||||
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
|
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {0});
|
||||||
|
@ -2557,7 +2740,7 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) {
|
||||||
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
|
NDArray cropSize = NDArrayFactory::create<int>({3, 3});
|
||||||
|
|
||||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
NDArray expected('c', {1,3,3,1}, {1, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::crop_and_resize op;
|
nd4j::ops::crop_and_resize op;
|
||||||
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
|
auto results = op.execute({&images, &boxes, &boxI, &cropSize}, {}, {1});
|
||||||
|
@ -2726,7 +2909,7 @@ TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) {
|
||||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp('c', {2,3}, {-63.75, -63.75, -63.75, -63.5, 0., 0.}, nd4j::DataType::FLOAT32);
|
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
@ -2971,22 +3154,6 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* public void testFakeQuantAgainstTF_1() {
|
|
||||||
INDArray x = Nd4j.createFromArray(new float[]{ 0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
|
||||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
|
||||||
0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}).reshape(3,5);
|
|
||||||
INDArray min = Nd4j.createFromArray(new float[]{-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}).reshape(1,5);
|
|
||||||
INDArray max = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}).reshape(1,5);
|
|
||||||
|
|
||||||
INDArray out = Nd4j.createUninitialized(x.shape());
|
|
||||||
val op = new FakeQuantWithMinMaxVarsPerChannel(x,min,max,out);
|
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f,
|
|
||||||
0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f,
|
|
||||||
0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f}).reshape(3,5);
|
|
||||||
|
|
||||||
assertEquals(expected, out);
|
|
||||||
}*/
|
|
||||||
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) {
|
||||||
NDArray x = NDArrayFactory::create<float>('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
NDArray x = NDArrayFactory::create<float>('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f,
|
||||||
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f,
|
||||||
|
@ -3094,12 +3261,12 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test1) {
|
||||||
|
|
||||||
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,4}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,4}, {11.61218734, 18.52390321, -8.67185076, -21.28716864, 10.93337162, 19.14541765, -9.26213931, -20.71509369}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
@ -3211,19 +3378,19 @@ TYPED_TEST(TypedDeclarableOpsTests10, batchnorm_test4) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
||||||
|
|
||||||
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,4,2,2}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,4,2,2}, {11.612187, 11.442483, 11.272779, 11.103076, 18.990039, 19.145418, 19.300796, 19.456175, -9.557284, -9.704856, -9.852428, -10., -20.,
|
NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f,
|
||||||
-19.856981, -19.713963, -19.570944, 8.896924, 8.727221, 8.557517, 8.387813, 21.476097, 21.631475, 21.786854, 21.942233, -11.918438,
|
-19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f,
|
||||||
-12.06601 , -12.213582, -12.361154, -17.7117, -17.568681, -17.425663, -17.282644}, nd4j::DataType::FLOAT32);
|
-12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, nd4j::DataType::FLOAT32);
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
nd4j::ops::batchnorm op;
|
||||||
|
|
||||||
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1});
|
auto results = op.execute({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1});
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
@ -3240,14 +3407,14 @@ TEST_F(DeclarableOpsTests10, batchnorm_test5) {
|
||||||
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
|
TEST_F(DeclarableOpsTests10, batchnorm_test6) {
|
||||||
|
|
||||||
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
NDArray input ('c', {2,2,2,4}, nd4j::DataType::FLOAT32);
|
||||||
NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, nd4j::DataType::FLOAT32);
|
NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, nd4j::DataType::FLOAT32);
|
NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, nd4j::DataType::FLOAT32);
|
NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray beta ('c', {4}, {10, 20, -10, -20}, nd4j::DataType::FLOAT32);
|
NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {2,2,2,4}, {11.612187, 18.523903, -8.671851, -21.287169, 10.933372, 19.145418, -9.262139, -20.715094, 10.254556, 19.766932, -9.852428, -20.143019, 9.57574 ,
|
NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f,
|
||||||
20.388447, -10.442716, -19.570944,8.896924, 21.009961, -11.033005, -18.998869, 8.218109, 21.631475, -11.623294, -18.426794, 7.539293, 22.25299 ,
|
20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f,
|
||||||
-12.213582, -17.854719, 6.860477, 22.874504, -12.803871, -17.282644}, nd4j::DataType::FLOAT32);
|
-12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, nd4j::DataType::FLOAT32);
|
||||||
input.linspace(0.1, 0.1);
|
input.linspace(0.1, 0.1);
|
||||||
|
|
||||||
nd4j::ops::batchnorm op;
|
nd4j::ops::batchnorm op;
|
||||||
|
@ -3270,7 +3437,7 @@ TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) {
|
||||||
NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32);
|
NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, nd4j::DataType::INT32);
|
||||||
NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32);
|
NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, nd4j::DataType::BOOL);
|
NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
NDArray result('c', {2,2,2}, nd4j::DataType::BOOL);
|
NDArray result('c', {2,2,2}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
|
|
|
@ -1257,7 +1257,7 @@ TEST_F(DeclarableOpsTests12, inTopK_2) {
|
||||||
auto input = NDArrayFactory::create<double>('c', {4, 5});
|
auto input = NDArrayFactory::create<double>('c', {4, 5});
|
||||||
auto idx = NDArrayFactory::create<Nd4jLong>('c', {4});
|
auto idx = NDArrayFactory::create<Nd4jLong>('c', {4});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>({0, 0, 0, 1});
|
auto exp = NDArrayFactory::create<bool>({false, false, false, true});
|
||||||
|
|
||||||
int exclusive, reverse;
|
int exclusive, reverse;
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
@ -1318,7 +1318,7 @@ TEST_F(DeclarableOpsTests12, inTopK_4) {
|
||||||
TEST_F(DeclarableOpsTests12, inTopK_5) {
|
TEST_F(DeclarableOpsTests12, inTopK_5) {
|
||||||
auto x = NDArrayFactory::create<double>('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} );
|
auto x = NDArrayFactory::create<double>('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} );
|
||||||
auto y = NDArrayFactory::create<Nd4jLong>('f', {6}, {0, 0, 0, 0, 0, 0});
|
auto y = NDArrayFactory::create<Nd4jLong>('f', {6}, {0, 0, 0, 0, 0, 0});
|
||||||
auto expV = NDArrayFactory::create<bool>('f', {6}, {1, 0, 0, 0, 0, 0 });
|
auto expV = NDArrayFactory::create<bool>('f', {6}, {true, false, false, false, false, false });
|
||||||
|
|
||||||
nd4j::ops::in_top_k op;
|
nd4j::ops::in_top_k op;
|
||||||
auto result = op.execute({&x, &y}, {}, {2});
|
auto result = op.execute({&x, &y}, {}, {2});
|
||||||
|
|
|
@ -1167,12 +1167,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_3) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139,
|
NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f,
|
||||||
0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106,
|
0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 0.605106f, 0.605106f,
|
||||||
0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32);
|
0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1230,12 +1230,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
|
||||||
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
|
@ -1245,18 +1245,19 @@ TEST_F(DeclarableOpsTests13, lstmLayer_4) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289,
|
NDArray expH('c', {sL, bS, 2 * nOut}, {
|
||||||
-0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647,
|
0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f,
|
||||||
0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395,
|
-0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f,
|
||||||
-0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862,
|
0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f,
|
||||||
0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308,
|
-0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f,
|
||||||
-0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128,
|
0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f,
|
||||||
0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
|
-0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f,
|
||||||
|
0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
|
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f,
|
||||||
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
|
-0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
|
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f,
|
||||||
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
|
-0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1328,16 +1329,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_5) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406,
|
NDArray expH('c', {bS, sL, 2*nOut}, {
|
||||||
0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 ,
|
0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f,
|
||||||
0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739,
|
0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f,
|
||||||
0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681,
|
0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f,
|
||||||
0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
|
0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f,
|
||||||
|
0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492,
|
NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f,
|
||||||
-0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32);
|
-0.107659f, -0.107659f, -0.107659f, -0.102739f, -0.102739f, -0.102739f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609,
|
NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f,
|
||||||
-0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32);
|
-0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1398,12 +1400,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
|
||||||
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
|
@ -1413,14 +1415,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_6) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491,
|
0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f,
|
||||||
0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32);
|
0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f,
|
||||||
|
0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
|
NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f,
|
||||||
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
|
-0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f},
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
|
nd4j::DataType::FLOAT32);
|
||||||
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f,
|
||||||
|
-0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1568,12 +1573,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_8) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
|
0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f,
|
||||||
0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);
|
0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f,
|
||||||
|
0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1650,16 +1656,17 @@ TEST_F(DeclarableOpsTests13, lstmLayer_9) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843,
|
NDArray expH('c', {sL, bS, 2*nOut}, {
|
||||||
0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139,
|
0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f,
|
||||||
0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 ,
|
0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f,
|
||||||
0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 ,
|
0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f,
|
||||||
0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32);
|
0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f,
|
||||||
|
0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502,
|
NDArray expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f,
|
||||||
-0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32);
|
-0.103843f, -0.103843f, -0.103843f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425,
|
NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f,
|
||||||
-0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32);
|
-0.292174f, -0.292174f, -0.292174f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1731,14 +1738,20 @@ TEST_F(DeclarableOpsTests13, lstmLayer_10) {
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0.,
|
0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f,
|
||||||
0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0.,
|
0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f,
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f,
|
||||||
|
0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1799,25 +1812,26 @@ TEST_F(DeclarableOpsTests13, lstmLayer_11) {
|
||||||
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
|
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx = 0.003;
|
Wx = 0.003f;
|
||||||
Wr = 0.006;
|
Wr = 0.006f;
|
||||||
b = 0.5;
|
b = 0.5f;
|
||||||
hI = 1.;
|
hI = 1.f;
|
||||||
cI = 2.;
|
cI = 2.f;
|
||||||
Wp = -0.05;
|
Wp = -0.05f;
|
||||||
|
|
||||||
std::initializer_list<double> tArgs = {cellClip};
|
std::initializer_list<double> tArgs = {cellClip};
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
|
||||||
|
|
||||||
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209,
|
NDArray expH('c', {sL, bS, nOut}, {
|
||||||
0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0.,
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f,
|
||||||
0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627,
|
0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
|
0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f,
|
||||||
0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32);
|
0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
|
||||||
|
0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32);
|
NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32);
|
NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
@ -1878,18 +1892,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
|
||||||
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
|
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.linspace(0.5, 0.5);
|
x.linspace(0.5, 0.5);
|
||||||
Wx({0,1, 0,0, 0,0}) = 0.003;
|
Wx({0,1, 0,0, 0,0}) = 0.003f;
|
||||||
Wx({1,2, 0,0, 0,0}) = -0.003;
|
Wx({1,2, 0,0, 0,0}) = -0.003f;
|
||||||
Wr({0,1, 0,0, 0,0}) = 0.006;
|
Wr({0,1, 0,0, 0,0}) = 0.006f;
|
||||||
Wr({1,2, 0,0, 0,0}) = -0.006;
|
Wr({1,2, 0,0, 0,0}) = -0.006f;
|
||||||
b({0,1, 0,0}) = 0.5;
|
b({0,1, 0,0}) = 0.5f;
|
||||||
b({1,2, 0,0}) = -0.5;
|
b({1,2, 0,0}) = -0.5f;
|
||||||
hI({0,1, 0,0, 0,0}) = 1;
|
hI({0,1, 0,0, 0,0}) = 1;
|
||||||
hI({1,2, 0,0, 0,0}) = -1;
|
hI({1,2, 0,0, 0,0}) = -1;
|
||||||
cI({0,1, 0,0, 0,0}) = 2;
|
cI({0,1, 0,0, 0,0}) = 2;
|
||||||
cI({1,2, 0,0, 0,0}) = -2;
|
cI({1,2, 0,0, 0,0}) = -2;
|
||||||
Wp({0,1, 0,0}) = -0.05;
|
Wp({0,1, 0,0}) = -0.05f;
|
||||||
Wp({1,2, 0,0}) = 0.05;
|
Wp({1,2, 0,0}) = 0.05f;
|
||||||
|
|
||||||
std::initializer_list<double> tArgs = {cellClip};
|
std::initializer_list<double> tArgs = {cellClip};
|
||||||
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
|
||||||
|
@ -1905,10 +1919,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_12) {
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
|
||||||
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315,
|
NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f,
|
||||||
0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702,
|
NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f,
|
||||||
0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32);
|
0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::lstmLayer op;
|
nd4j::ops::lstmLayer op;
|
||||||
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
|
||||||
|
|
|
@ -148,8 +148,8 @@ TEST_F(DeclarableOpsTests15, Test_standarize_1) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5}, {1., 1., 1., 1., 1.});
|
auto x = NDArrayFactory::create<float>('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f});
|
||||||
auto eps = NDArrayFactory::create<float>('c', {5}, {0., 0., 0., 0., 0.});
|
auto eps = NDArrayFactory::create<float>('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f});
|
||||||
|
|
||||||
nd4j::ops::standardize_bp op;
|
nd4j::ops::standardize_bp op;
|
||||||
auto result = op.execute({&x, &eps}, {}, {0}, {});
|
auto result = op.execute({&x, &eps}, {}, {0}, {});
|
||||||
|
|
|
@ -197,3 +197,44 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
|
||||||
|
|
||||||
delete shapes;
|
delete shapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests16, test_reverse_1) {
|
||||||
|
std::vector<Nd4jLong> rows = {3, 5, 7, 8, 9, 10, 119, 211};
|
||||||
|
std::vector<Nd4jLong> columns = {6, 5, 10, 100, 153, 171, 635};
|
||||||
|
|
||||||
|
for (auto r : rows) {
|
||||||
|
for (auto c : columns) {
|
||||||
|
//nd4j_printf("Trying [%i, %i]\n", r, c);
|
||||||
|
auto array = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
auto reversed = NDArrayFactory::create<float>('c', {r, c});
|
||||||
|
|
||||||
|
auto rowOriginal = NDArrayFactory::create<float>('c', {c});
|
||||||
|
auto rowReversed = NDArrayFactory::create<float>('c', {c});
|
||||||
|
|
||||||
|
for (int e = 0; e < c; e++) {
|
||||||
|
rowOriginal.p(e, (float) e);
|
||||||
|
rowReversed.p(c - e - 1, (float) e);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
auto listI = array.allTensorsAlongDimension({1});
|
||||||
|
auto listE = exp.allTensorsAlongDimension({1});
|
||||||
|
|
||||||
|
for (int e = 0; e < r; e++) {
|
||||||
|
listI->at(e)->assign(rowOriginal);
|
||||||
|
listE->at(e)->assign(rowReversed);
|
||||||
|
}
|
||||||
|
|
||||||
|
delete listI;
|
||||||
|
delete listE;
|
||||||
|
|
||||||
|
nd4j::ops::reverse op;
|
||||||
|
Nd4jLong axis = 1;
|
||||||
|
auto status = op.execute({&array}, {&reversed}, {}, {axis}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
ASSERT_EQ(exp, reversed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1591,7 +1591,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -71.);
|
ASSERT_TRUE(result->e<float>(0) == -71.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1616,7 +1616,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -69.);
|
ASSERT_TRUE(result->e<float>(0) == -69.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1630,8 +1630,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
@ -1641,7 +1641,7 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) {
|
||||||
auto *result = results->at(0);
|
auto *result = results->at(0);
|
||||||
|
|
||||||
ASSERT_TRUE(result->isScalar());
|
ASSERT_TRUE(result->isScalar());
|
||||||
ASSERT_TRUE(result->e<float>(0) == -24.);
|
ASSERT_TRUE(result->e<float>(0) == -24.f);
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
||||||
|
@ -1655,8 +1655,8 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {1,1});
|
auto weights = NDArrayFactory::create<float>('c', {1,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
@ -1680,10 +1680,10 @@ TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
auto weights = NDArrayFactory::create<float>('c', {2,3,1});
|
||||||
|
|
||||||
labels.linspace(1);
|
labels.linspace(1);
|
||||||
weights.assign(0.5);
|
weights.assign(0.5f);
|
||||||
predictions.assign(0.5);
|
predictions.assign(0.5f);
|
||||||
weights.p(0, 0.);
|
weights.p(0, 0.f);
|
||||||
weights.p(1, 0.);
|
weights.p(1, 0.f);
|
||||||
|
|
||||||
nd4j::ops::cosine_distance_loss op;
|
nd4j::ops::cosine_distance_loss op;
|
||||||
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
auto results = op.execute({&predictions, &weights, &labels}, {}, {2,2});
|
||||||
|
|
|
@ -1707,7 +1707,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
|
||||||
b.linspace(10.);
|
b.linspace(10.);
|
||||||
x.assign(1.);
|
x.assign(1.);
|
||||||
|
|
||||||
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.,1.,1.,1.,1.,1.,1.});
|
auto expected= NDArrayFactory::create<float>('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f});
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.execute({&a, &b, &x}, {}, {});
|
||||||
|
@ -2292,9 +2292,9 @@ TEST_F(DeclarableOpsTests3, svd_test3) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2329,9 +2329,9 @@ TEST_F(DeclarableOpsTests3, svd_test4) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2366,9 +2366,9 @@ TEST_F(DeclarableOpsTests3, svd_test5) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
@ -2421,9 +2421,9 @@ TEST_F(DeclarableOpsTests3, svd_test6) {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for(uint i = 0; i < expU.lengthOf(); ++i)
|
for(uint i = 0; i < expU.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expU.e<float>(i)), nd4j::math::nd4j_abs(u->e<float>(i)), 1e-5f);
|
||||||
for(uint i = 0; i < expV.lengthOf(); ++i)
|
for(uint i = 0; i < expV.lengthOf(); ++i)
|
||||||
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5);
|
ASSERT_NEAR(nd4j::math::nd4j_abs(expV.e<float>(i)), nd4j::math::nd4j_abs(v->e<float>(i)), 1e-5f);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
|
@ -4084,7 +4084,7 @@ TEST_F(DeclarableOpsTests7, Softsign_1) {
|
||||||
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
|
TEST_F(DeclarableOpsTests7, Softsign_BP_1) {
|
||||||
|
|
||||||
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
|
NDArray x = NDArrayFactory::create<double >('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11});
|
||||||
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016});
|
// NDArray e = NDArrayFactory::create<float>('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f});
|
||||||
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
|
NDArray eps = NDArrayFactory::create<double>('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10});
|
||||||
nd4j::ops::softsign ffOP;
|
nd4j::ops::softsign ffOP;
|
||||||
nd4j::ops::softsign_bp bpOp;
|
nd4j::ops::softsign_bp bpOp;
|
||||||
|
|
|
@ -24,6 +24,7 @@
|
||||||
#include <NDArray.h>
|
#include <NDArray.h>
|
||||||
#include <ops/ops.h>
|
#include <ops/ops.h>
|
||||||
#include <GradCheck.h>
|
#include <GradCheck.h>
|
||||||
|
#include <chrono>
|
||||||
|
|
||||||
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
@ -58,5 +59,20 @@ TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) {
|
||||||
//ASSERT_TRUE(exp.isSameShape(z));
|
//ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 3, 608, 608});
|
||||||
|
auto z = x.like();
|
||||||
|
x.linspace(1.0f);
|
||||||
|
|
||||||
|
nd4j::ops::reverse op;
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
auto status = op.execute({&x}, {&z}, {}, {1}, {});
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds> (timeEnd - timeStart).count();
|
||||||
|
nd4j_printf("exec time: %lld us\n", outerTime);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
*/
|
|
@ -661,9 +661,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
|
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 1, 2});
|
||||||
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
|
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 0, 0});
|
||||||
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
|
// auto o = NDArrayFactory::create<float>('c', {2, 2}, {3, 3, 3, 3});
|
||||||
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
|
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
NDArray::prepareSpecialUse({&o}, {&x, &y});
|
||||||
|
|
||||||
|
@ -685,9 +685,9 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
||||||
TEST_F(JavaInteropTests, Test_Greater_2) {
|
TEST_F(JavaInteropTests, Test_Greater_2) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
|
auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 1.f, 2.f});
|
||||||
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
|
auto y = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 0.f, 0.f});
|
||||||
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {1, 1, 1, 1});
|
auto o = NDArrayFactory::create<bool>('c', {2, 2}, {true, true, true, true});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {0, 0, 1, 1});
|
auto exp = NDArrayFactory::create<bool>('c', {2, 2}, {false, false, true, true});
|
||||||
|
|
||||||
nd4j::ops::greater op;
|
nd4j::ops::greater op;
|
||||||
|
|
||||||
|
|
|
@ -1163,10 +1163,10 @@ TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) {
|
||||||
NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
||||||
NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
NDArray exp1('c', {3}, {4., 20., 36.}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {2,3}, {-10., -2., 6.,14., 22., 30.}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {4}, {38., 41., 44., 47.}, nd4j::DataType::FLOAT32);
|
NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp4('c', {4}, {114., 117., 120., 123.}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
|
||||||
NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2});
|
NDArray* z = x.applyReduce3(nd4j::reduce3::Dot, &y, {0,2});
|
||||||
|
@ -1271,8 +1271,10 @@ TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) {
|
||||||
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
NDArray x4('c', {3,2}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray exp1('c', {3,2}, {-88., -124., 6., -2., 22., 14.}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {6,4}, {-36., -44., -52., -60.,-42., -52., -62., -72.,2., 0., -2., -4.,6., 4., 2., 0.,10., 8., 6., 4.,14., 12., 10., 8.}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f,
|
||||||
|
-4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f},
|
||||||
|
nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
|
NDArray exp3('c', {1,1}, {31.5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
|
NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
@ -1400,10 +1402,10 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) {
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
NDArray exp1('c', {}, {2.166667}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp2('c', {2,2}, {3,4,1,0.666667}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
NDArray exp3('c', {3}, {4.5,1,1}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp5('c', {2}, {3.5,0.833333}, nd4j::DataType::FLOAT32);
|
NDArray exp5('c', {2}, {3.5f,0.833333f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::Mean, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1503,7 +1505,7 @@ TEST_F(NDArrayCudaBasicsTests, EqualityTest1) {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
|
NDArray z1('c', {}, {100}, nd4j::DataType::FLOAT32);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
@ -1511,11 +1513,11 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) {
|
||||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::FLOAT32);
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
NDArray z5('c', {2}, {100,100}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {26.5}, nd4j::DataType::FLOAT32);
|
NDArray exp1('c', {}, {26.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp2('c', {2,2}, {9.5,12,3,2}, nd4j::DataType::FLOAT32);
|
NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp3('c', {3}, {19,4,3.5}, nd4j::DataType::FLOAT32);
|
NDArray exp3('c', {3}, {19.f,4.f,3.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp4('c', {3,2}, {9,10,2,2,1.5,2}, nd4j::DataType::FLOAT32);
|
NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray exp5('c', {2}, {21.5,5}, nd4j::DataType::FLOAT32);
|
NDArray exp5('c', {2}, {21.5f,5.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::Sum, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1575,17 +1577,17 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::BOOL);
|
NDArray z1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::BOOL);
|
NDArray z2('c', {2,2}, {true,true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z3('c', {3}, {100,100,100}, nd4j::DataType::BOOL);
|
NDArray z3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z4('c', {3,2}, {100,100,100,100,100,100}, nd4j::DataType::BOOL);
|
NDArray z4('c', {3,2}, {true,true,true,true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray z5('c', {2}, {100,100}, nd4j::DataType::BOOL);
|
NDArray z5('c', {2}, {true,true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
NDArray exp1('c', {}, {1}, nd4j::DataType::BOOL);
|
NDArray exp1('c', {}, {true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp2('c', {2,2}, {1,1,0,1}, nd4j::DataType::BOOL);
|
NDArray exp2('c', {2,2}, {true,true,false,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp3('c', {3}, {1,1,1}, nd4j::DataType::BOOL);
|
NDArray exp3('c', {3}, {true,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp4('c', {3,2}, {1,1,1,0,1,1}, nd4j::DataType::BOOL);
|
NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, nd4j::DataType::BOOL);
|
||||||
NDArray exp5('c', {2}, {1,1}, nd4j::DataType::BOOL);
|
NDArray exp5('c', {2}, {true,true}, nd4j::DataType::BOOL);
|
||||||
|
|
||||||
x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2});
|
x.reduceAlongDimension(nd4j::reduce::IsPositive, &z1, {0,1,2});
|
||||||
ASSERT_TRUE(z1.equalsTo(&exp1));
|
ASSERT_TRUE(z1.equalsTo(&exp1));
|
||||||
|
@ -1643,7 +1645,7 @@ TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) {
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
|
TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) {
|
||||||
|
|
||||||
NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
|
NDArray z1('c', {}, {100}, nd4j::DataType::INT64);
|
||||||
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
|
NDArray z2('c', {2,2}, {100,100,100,100}, nd4j::DataType::INT64);
|
||||||
|
@ -1912,7 +1914,7 @@ TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3)
|
||||||
TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
||||||
{
|
{
|
||||||
double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.};
|
double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.};
|
||||||
NDArray a('c', {4,4}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, nd4j::DataType::FLOAT32);
|
NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, nd4j::DataType::FLOAT32);
|
||||||
auto x = NDArrayFactory::create<double>('c', {3, 2, 1});
|
auto x = NDArrayFactory::create<double>('c', {3, 2, 1});
|
||||||
auto y = NDArrayFactory::create<double>('c', {1, 2});
|
auto y = NDArrayFactory::create<double>('c', {1, 2});
|
||||||
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {3, 2, 2});
|
auto expected = NDArrayFactory::create<double>(expBuff, 'c', {3, 2, 2});
|
||||||
|
@ -1928,7 +1930,7 @@ TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2)
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(NDArrayCudaBasicsTests, assign_2)
|
TEST_F(NDArrayCudaBasicsTests, assign_2)
|
||||||
{
|
{
|
||||||
NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray y('c', {4}, nd4j::DataType::INT32);
|
NDArray y('c', {4}, nd4j::DataType::INT32);
|
||||||
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
NDArray expected('c', {4}, {1,2,3,4}, nd4j::DataType::INT32);
|
||||||
|
|
||||||
|
@ -1945,30 +1947,30 @@ TEST_F(NDArrayCudaBasicsTests, subarray_1)
|
||||||
NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99};
|
Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99};
|
||||||
float buffExpX0[] = {1.000000, 13.000000};
|
float buffExpX0[] = {1.f, 13.f};
|
||||||
Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99};
|
Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99};
|
||||||
float buffExpX1[] = {2.000000, 14.000000};
|
float buffExpX1[] = {2.f, 14.f};
|
||||||
Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99};
|
||||||
float buffExpX2[] = {1.000000, 13.000000};
|
float buffExpX2[] = {1.f, 13.f};
|
||||||
Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99};
|
||||||
float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
|
float buffExpX3[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
|
||||||
Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99};
|
Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99};
|
||||||
float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000};
|
float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f};
|
||||||
Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99};
|
Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99};
|
||||||
float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000};
|
float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f};
|
||||||
|
|
||||||
Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99};
|
Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99};
|
||||||
float buffExpY0[] = {1.000000, 2.000000};
|
float buffExpY0[] = {1.f, 2.f};
|
||||||
Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99};
|
Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99};
|
||||||
float buffExpY1[] = {7.000000, 8.000000};
|
float buffExpY1[] = {7.f, 8.f};
|
||||||
Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102};
|
Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102};
|
||||||
float buffExpY2[] = {1.000000, 2.000000};
|
float buffExpY2[] = {1.f, 2.f};
|
||||||
Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99};
|
Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99};
|
||||||
float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
|
float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
|
||||||
Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102};
|
Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102};
|
||||||
float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000};
|
float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f};
|
||||||
Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99};
|
Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99};
|
||||||
float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000};
|
float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f};
|
||||||
|
|
||||||
|
|
||||||
NDArray x0 = x(0, {1,2});
|
NDArray x0 = x(0, {1,2});
|
||||||
|
@ -2121,7 +2123,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
||||||
auto x = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
auto x = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x->reshapei('c', {3, 4, 5});
|
x->reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x->permutei({0, 1, 2});
|
x->permutei({0, 1, 2});
|
||||||
|
@ -2138,7 +2140,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
@ -2153,7 +2155,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) {
|
TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
auto x = NDArrayFactory::create<float>('c', {1, 60});
|
||||||
x.linspace(1);
|
x.linspace(1);
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
@ -2170,7 +2172,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) {
|
||||||
auto xx = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
auto xx = NDArrayFactory::linspace<float>(1.f, 60.f, 60); //('c', {1, 60});
|
||||||
// auto x = *xx;
|
// auto x = *xx;
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
// auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
// x.reshapei('c', {3, 4, 5});
|
// x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
// x.permutei({0, 1, 2});
|
// x.permutei({0, 1, 2});
|
||||||
|
@ -2188,7 +2190,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) {
|
||||||
//x.linspace(1);
|
//x.linspace(1);
|
||||||
for (int l = 0; l < x.lengthOf(); l++)
|
for (int l = 0; l < x.lengthOf(); l++)
|
||||||
x.p(l, float(l + 1.f));
|
x.p(l, float(l + 1.f));
|
||||||
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0});
|
auto exp = NDArrayFactory::create<float>('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0});
|
||||||
x.reshapei('c', {3, 4, 5});
|
x.reshapei('c', {3, 4, 5});
|
||||||
|
|
||||||
x.permutei({0, 1, 2});
|
x.permutei({0, 1, 2});
|
||||||
|
|
File diff suppressed because one or more lines are too long
|
@ -103,6 +103,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3D.class,
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class NonMaxSuppression extends DynamicCustomOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String[] tensorflowNames() {
|
public String[] tensorflowNames() {
|
||||||
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2","NonMaxSuppressionV3","NonMaxSuppressionV4"};
|
return new String[]{"NonMaxSuppression", "NonMaxSuppressionV2"};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.image;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.val;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
|
@ -43,20 +44,25 @@ import java.util.Map;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class ResizeBilinear extends DynamicCustomOp {
|
public class ResizeBilinear extends DynamicCustomOp {
|
||||||
protected boolean alignCorners = false;
|
protected boolean alignCorners = false;
|
||||||
|
protected boolean halfPixelCenters = false;
|
||||||
protected Integer height = null;
|
protected Integer height = null;
|
||||||
protected Integer width = null;
|
protected Integer width = null;
|
||||||
|
|
||||||
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width, boolean alignCorners){
|
public ResizeBilinear(@NonNull SameDiff sd, @NonNull SDVariable input, int height, int width,
|
||||||
|
boolean alignCorners, boolean halfPixelCenters){
|
||||||
super(sd, input);
|
super(sd, input);
|
||||||
this.alignCorners = alignCorners;
|
this.alignCorners = alignCorners;
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
|
this.halfPixelCenters = halfPixelCenters;
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width, boolean alignCorners){
|
public ResizeBilinear(@NonNull INDArray x, INDArray z, int height, int width,
|
||||||
|
boolean alignCorners, boolean halfPixelCenters) {
|
||||||
super(new INDArray[]{x}, new INDArray[]{z});
|
super(new INDArray[]{x}, new INDArray[]{z});
|
||||||
this.alignCorners = alignCorners;
|
this.alignCorners = alignCorners;
|
||||||
|
this.halfPixelCenters = halfPixelCenters;
|
||||||
this.height = height;
|
this.height = height;
|
||||||
this.width = width;
|
this.width = width;
|
||||||
addArgs();
|
addArgs();
|
||||||
|
@ -76,7 +82,12 @@ public class ResizeBilinear extends DynamicCustomOp {
|
||||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
|
||||||
|
|
||||||
this.alignCorners = attributesForNode.get("align_corners").getB();
|
val attrC = attributesForNode.get("align_corners");
|
||||||
|
val attrH = attributesForNode.get("half_pixel_centers");
|
||||||
|
|
||||||
|
this.alignCorners = attrC != null ? attrC.getB() : false;
|
||||||
|
this.halfPixelCenters = attrH != null ? attrH.getB() : false;
|
||||||
|
|
||||||
addArgs();
|
addArgs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,8 +98,7 @@ public class ResizeBilinear extends DynamicCustomOp {
|
||||||
iArguments.add(Long.valueOf(height));
|
iArguments.add(Long.valueOf(height));
|
||||||
iArguments.add(Long.valueOf(width));
|
iArguments.add(Long.valueOf(width));
|
||||||
}
|
}
|
||||||
iArguments.add(alignCorners ? 1L : 0L);
|
addBArgument(alignCorners, halfPixelCenters);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -204,7 +204,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
|
||||||
if(attributesForNode.containsKey("argmax")) {
|
if(attributesForNode.containsKey("argmax")) {
|
||||||
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
|
outputType = TFGraphMapper.convertType(attributesForNode.get("argmax").getType());
|
||||||
} else {
|
} else {
|
||||||
outputType = DataType.UINT32;
|
outputType = DataType.LONG;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -278,7 +278,7 @@ public class MaxPoolWithArgmax extends DynamicCustomOp {
|
||||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected 1 input data type for %s, got %s", getClass(), inputDataTypes);
|
||||||
List<DataType> result = new ArrayList<>();
|
List<DataType> result = new ArrayList<>();
|
||||||
result.add(inputDataTypes.get(0));
|
result.add(inputDataTypes.get(0));
|
||||||
result.add(outputType == null ? DataType.UINT32 : outputType);
|
result.add(outputType == null ? DataType.INT : outputType);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4584,6 +4584,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* returns reference on array element with given index
|
* returns reference on array element with given index
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
* i - element index in array
|
* i - element index in array
|
||||||
|
@ -5171,6 +5172,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
@ -5179,6 +5182,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// #ifndef __JAVACPP_HACK__
|
// #ifndef __JAVACPP_HACK__
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
|
|
@ -4587,6 +4587,7 @@ public native @Cast("bool") boolean isOptimalRequirementsMet();
|
||||||
* returns reference on array element with given index
|
* returns reference on array element with given index
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* returns array element with given index
|
* returns array element with given index
|
||||||
* i - element index in array
|
* i - element index in array
|
||||||
|
@ -5174,6 +5175,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
|
||||||
|
@ -5182,6 +5185,8 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// #ifndef __JAVACPP_HACK__
|
// #ifndef __JAVACPP_HACK__
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
@ -18280,7 +18285,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
/**
|
/**
|
||||||
* This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
|
* This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
|
||||||
* terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
|
* terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
|
||||||
* Currently the case n = 0 is not supported.
|
|
||||||
*
|
*
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
|
* 0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
|
||||||
|
@ -18309,6 +18313,34 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This op calculates digamma function psi(x) = derivative of log(Gamma(x))
|
||||||
|
*
|
||||||
|
* Input arrays:
|
||||||
|
* 0: x - abscissa points where to evaluate the digamma function, type float
|
||||||
|
*
|
||||||
|
* Output array:
|
||||||
|
* 0: values of digamma function at corresponding x, type float
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_digamma)
|
||||||
|
@Namespace("nd4j::ops") public static class digamma extends DeclarableOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public digamma(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public digamma(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public digamma position(long position) {
|
||||||
|
return (digamma)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public digamma() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
|
* This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
|
@ -18398,9 +18430,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image hue by delta
|
* This operation adjusts image hue by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing delta
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - delta value
|
* 0 - optional argument, delta value
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
|
@ -18427,9 +18460,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image saturation by delta
|
* This operation adjusts image saturation by delta
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing saturation factor
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - saturation factor
|
* 0 - optional argument, saturation factor
|
||||||
*
|
*
|
||||||
* Int arguments:
|
* Int arguments:
|
||||||
* 0 - optional argument, corresponds to dimension with 3 channels
|
* 0 - optional argument, corresponds to dimension with 3 channels
|
||||||
|
@ -18456,9 +18490,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
|
* This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
|
||||||
* Input arrays:
|
* Input arrays:
|
||||||
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
|
* 0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
|
||||||
|
* 1 - optional argument, input scalar-array containing saturation contrast factor
|
||||||
*
|
*
|
||||||
* T arguments:
|
* T arguments:
|
||||||
* 0 - contrast factor
|
* 0 - optional argument, contrast factor
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_adjust_contrast)
|
// #if NOT_EXCLUDED(OP_adjust_contrast)
|
||||||
|
|
|
@ -760,7 +760,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
.isSameMode(true)
|
.isSameMode(true)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"",""}, in, pooling2DConfig);
|
SDVariable[] results = sd.nn().maxPoolWithArgmax(new String[]{"out","idx"}, in, pooling2DConfig);
|
||||||
assertArrayEquals(inArr.shape(), results[0].eval().shape());
|
assertArrayEquals(inArr.shape(), results[0].eval().shape());
|
||||||
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
assertArrayEquals(inArr.shape(), results[1].eval().shape());
|
||||||
}
|
}
|
||||||
|
@ -1050,7 +1050,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
||||||
SDVariable in = sd.var("in", inArr);
|
SDVariable in = sd.var("in", inArr);
|
||||||
SDVariable w = sd.var("w", wArr);
|
SDVariable w = sd.var("w", wArr);
|
||||||
|
|
||||||
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).build());
|
SDVariable res = sd.cnn.conv1d(in, w, Conv1DConfig.builder().k(kernel).paddingMode(PaddingMode.VALID).build());
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(
|
INDArray expected = Nd4j.createFromArray(
|
||||||
new double[][][]{
|
new double[][][]{
|
||||||
|
|
|
@ -23,13 +23,7 @@ import static org.junit.Assert.fail;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig;
|
|
||||||
|
|
||||||
public class ConvConfigTests {
|
public class ConvConfigTests {
|
||||||
|
|
||||||
|
@ -489,24 +483,24 @@ public class ConvConfigTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConv1D(){
|
public void testConv1D(){
|
||||||
Conv1DConfig.builder().k(2).build();
|
Conv1DConfig.builder().k(2).paddingMode(PaddingMode.SAME).build();
|
||||||
|
|
||||||
try{
|
try{
|
||||||
Conv1DConfig.builder().k(0).build();
|
Conv1DConfig.builder().k(0).paddingMode(PaddingMode.SAME).build();
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e){
|
} catch (IllegalArgumentException e){
|
||||||
assertTrue(e.getMessage().contains("Kernel"));
|
assertTrue(e.getMessage().contains("Kernel"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try{
|
try{
|
||||||
Conv1DConfig.builder().k(4).s(-2).build();
|
Conv1DConfig.builder().k(4).s(-2).paddingMode(PaddingMode.SAME).build();
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e){
|
} catch (IllegalArgumentException e){
|
||||||
assertTrue(e.getMessage().contains("Stride"));
|
assertTrue(e.getMessage().contains("Stride"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try{
|
try{
|
||||||
Conv1DConfig.builder().k(3).p(-2).build();
|
Conv1DConfig.builder().k(3).p(-2).paddingMode(PaddingMode.SAME).build();
|
||||||
fail();
|
fail();
|
||||||
} catch (IllegalArgumentException e){
|
} catch (IllegalArgumentException e){
|
||||||
assertTrue(e.getMessage().contains("Padding"));
|
assertTrue(e.getMessage().contains("Padding"));
|
||||||
|
|
|
@ -117,9 +117,6 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
||||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8402
|
||||||
"fake_quant/min_max_args_per_channel.*",
|
"fake_quant/min_max_args_per_channel.*",
|
||||||
|
|
||||||
// 2019/11/15 - failure https://github.com/eclipse/deeplearning4j/issues/8403
|
|
||||||
"resize_bilinear/int32.*",
|
|
||||||
|
|
||||||
// Suggesting TF 1.15 bug
|
// Suggesting TF 1.15 bug
|
||||||
"non_max_suppression_v2/float16.*",
|
"non_max_suppression_v2/float16.*",
|
||||||
|
|
||||||
|
|
|
@ -972,7 +972,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
INDArray x = Nd4j.rand(1, 2,3,4);
|
INDArray x = Nd4j.rand(1, 2,3,4);
|
||||||
INDArray z = Nd4j.createUninitialized(x.shape());
|
INDArray z = Nd4j.createUninitialized(x.shape());
|
||||||
boolean align = false;
|
boolean align = false;
|
||||||
val op = new ResizeBilinear(x, z, 10, 10, align);
|
val op = new ResizeBilinear(x, z, 10, 10, align, false);
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1174,6 +1174,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertEquals(expected, x);
|
assertEquals(expected, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Ignore("AS failed 2019/12/04")
|
||||||
@Test
|
@Test
|
||||||
public void testPolygamma() {
|
public void testPolygamma() {
|
||||||
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
|
INDArray n = Nd4j.linspace(DataType.FLOAT, 1.0, 1.0, 9).reshape(3,3);
|
||||||
|
|
Loading…
Reference in New Issue