Dev branch merge: dev_20190606 (#7904)
* correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. * couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. * broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaksmaster
parent
32e5cc1945
commit
68ea5f3688
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.api.Layer;
|
|||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
|
||||
|
@ -283,7 +284,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
|||
net.computeGradientAndScore();
|
||||
net2.computeGradientAndScore();
|
||||
|
||||
System.out.println(net.score() + "\t" + net2.score());
|
||||
assertEquals(net2.score(), net.score(), 1e-6);
|
||||
|
||||
Map<String, INDArray> gradient = net.gradient().gradientForVariable();
|
||||
|
@ -441,6 +441,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
|||
int numInputClasses = 10;
|
||||
int timeSeriesLength = 5;
|
||||
|
||||
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||
for (int nExamples : miniBatchSizes) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -492,7 +493,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
INDArray inputMask = Nd4j.zeros(nExamples, timeSeriesLength);
|
||||
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
|
||||
for (int i = 0; i < nExamples; i++) {
|
||||
for (int j = 0; j < timeSeriesLength; j++) {
|
||||
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
|
||||
|
@ -523,6 +524,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
|
@ -583,6 +585,104 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmbeddingSequenceLayerWithMasking() {
|
||||
//Idea: have masking on the input with an embedding and dense layers on input
|
||||
//Ensure that the parameter gradients for the inputs don't depend on the inputs when inputs are masked
|
||||
|
||||
int[] miniBatchSizes = {1, 3};
|
||||
int nIn = 2;
|
||||
Random r = new Random(12345);
|
||||
|
||||
int numInputClasses = 10;
|
||||
int timeSeriesLength = 5;
|
||||
|
||||
for (DataType maskDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||
for (DataType inLabelDtype : new DataType[]{DataType.FLOAT, DataType.DOUBLE, DataType.INT}) {
|
||||
for(int inputRank : new int[]{2, 3}) {
|
||||
for (int nExamples : miniBatchSizes) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.updater(new Sgd(0.1)).seed(12345).list()
|
||||
.layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses)
|
||||
.nOut(5).build())
|
||||
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||
.nOut(4).build())
|
||||
.setInputType(InputType.recurrent(1)).build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
.updater(new Sgd(0.1)).seed(12345).list()
|
||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5)
|
||||
.build())
|
||||
.layer(1, new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(4).build())
|
||||
.layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
|
||||
.layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3)
|
||||
.nOut(4).build())
|
||||
.setInputType(InputType.recurrent(1)).build();
|
||||
|
||||
MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
|
||||
net2.init();
|
||||
|
||||
net2.setParams(net.params().dup());
|
||||
|
||||
INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength});
|
||||
INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength);
|
||||
|
||||
INDArray labels = Nd4j.zeros(inLabelDtype, nExamples, 4, timeSeriesLength);
|
||||
|
||||
for (int i = 0; i < nExamples; i++) {
|
||||
for (int j = 0; j < timeSeriesLength; j++) {
|
||||
int inIdx = r.nextInt(numInputClasses);
|
||||
inEmbedding.putScalar(inputRank == 2 ? new int[]{i, j} : new int[]{i, 0, j}, inIdx);
|
||||
inDense.putScalar(new int[]{i, inIdx, j}, 1.0);
|
||||
|
||||
int outIdx = r.nextInt(4);
|
||||
labels.putScalar(new int[]{i, outIdx, j}, 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
INDArray inputMask = Nd4j.zeros(maskDtype, nExamples, timeSeriesLength);
|
||||
for (int i = 0; i < nExamples; i++) {
|
||||
for (int j = 0; j < timeSeriesLength; j++) {
|
||||
inputMask.putScalar(new int[]{i, j}, (r.nextBoolean() ? 1.0 : 0.0));
|
||||
}
|
||||
}
|
||||
|
||||
net.setLayerMaskArrays(inputMask, null);
|
||||
net2.setLayerMaskArrays(inputMask, null);
|
||||
List<INDArray> actEmbedding = net.feedForward(inEmbedding, false);
|
||||
List<INDArray> actDense = net2.feedForward(inDense, false);
|
||||
for (int i = 2; i < actEmbedding.size(); i++) { //Start from layer 2: EmbeddingSequence is 3d, first dense is 2d (before reshape)
|
||||
assertEquals(actDense.get(i), actEmbedding.get(i));
|
||||
}
|
||||
|
||||
net.setLabels(labels);
|
||||
net2.setLabels(labels);
|
||||
net.computeGradientAndScore();
|
||||
net2.computeGradientAndScore();
|
||||
|
||||
assertEquals(net2.score(), net.score(), 1e-5);
|
||||
|
||||
Map<String, INDArray> gradients = net.gradient().gradientForVariable();
|
||||
Map<String, INDArray> gradients2 = net2.gradient().gradientForVariable();
|
||||
assertEquals(gradients.keySet(), gradients2.keySet());
|
||||
for (String s : gradients.keySet()) {
|
||||
assertEquals(gradients2.get(s), gradients.get(s));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@EqualsAndHashCode
|
||||
private static class WordVectorsMockup implements EmbeddingInitializer {
|
||||
|
||||
|
|
|
@ -213,6 +213,12 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
INDArray outLoaded = netLoaded.output(in);
|
||||
|
||||
assertEquals(msg, outExp, outLoaded);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = net.output(newIn);
|
||||
INDArray outMb = net2.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -306,6 +312,10 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
assertTrue(msg, gradOK);
|
||||
|
||||
TestUtils.testModelSerialization(net);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(f, f);
|
||||
net.output(newIn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -137,6 +137,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
INDArray outLoaded = netLoaded.output(in);
|
||||
|
||||
assertEquals(outExp, outLoaded);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = net.output(newIn);
|
||||
INDArray outMb = net2.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -314,6 +320,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
netSD.computeGradientAndScore();
|
||||
// netStandard.computeGradientAndScore();
|
||||
// assertEquals(netStandard.gradient().gradient(), netSD.gradient().gradient());
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = netSD.output(newIn);
|
||||
INDArray outMb = netStandard.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -377,6 +389,12 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
assertEquals(s, netStandard.params(), netSD.params());
|
||||
assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
|
||||
}
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures());
|
||||
INDArray outMbsd = netSD.output(newIn);
|
||||
INDArray outMb = netStandard.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -417,6 +435,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
assertTrue(msg, gradOK);
|
||||
|
||||
TestUtils.testModelSerialization(net);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(f, f);
|
||||
net.output(newIn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -166,6 +166,12 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest {
|
|||
outSD = loaded.outputSingle(in);
|
||||
outStd = netStandard.outputSingle(in);
|
||||
assertEquals(outStd, outSD);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = netSD.output(newIn)[0];
|
||||
INDArray outMb = netStandard.output(newIn)[0];
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -115,6 +115,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
|
|||
outStd = std.outputSingle(in);
|
||||
|
||||
assertEquals(outStd, outLambda);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = lambda.output(newIn)[0];
|
||||
INDArray outMb = std.output(newIn)[0];
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -186,5 +192,12 @@ public class TestSameDiffLambda extends BaseDL4JTest {
|
|||
outStd = std.output(in1, in2)[0];
|
||||
|
||||
assertEquals(outStd, outLambda);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn1 = Nd4j.vstack(in1, in1);
|
||||
INDArray newIn2 = Nd4j.vstack(in2, in2);
|
||||
INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
|
||||
INDArray outMb = std.output(newIn1, newIn2)[0];
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -90,6 +90,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
||||
net.init();
|
||||
net.fit(ds);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = netSD.output(newIn);
|
||||
INDArray outMb = netStd.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
|
||||
|
||||
|
@ -164,6 +170,12 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork(confSD.clone());
|
||||
net.init();
|
||||
net.fit(ds);
|
||||
|
||||
//Sanity check on different minibatch sizes:
|
||||
INDArray newIn = Nd4j.vstack(in, in);
|
||||
INDArray outMbsd = netSD.output(newIn);
|
||||
INDArray outMb = netStd.output(newIn);
|
||||
assertEquals(outMb, outMbsd);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
|
|||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
|
||||
import org.deeplearning4j.models.fasttext.FastText;
|
||||
import org.deeplearning4j.models.glove.Glove;
|
||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||
|
@ -3090,6 +3091,42 @@ public class WordVectorSerializer {
|
|||
return word2Vec;
|
||||
}
|
||||
|
||||
public static void writeWordVectors(@NonNull FastText vectors, @NonNull File path) throws IOException {
|
||||
ObjectOutputStream outputStream = null;
|
||||
try {
|
||||
outputStream = new ObjectOutputStream(new FileOutputStream(path ));
|
||||
outputStream.writeObject(vectors);
|
||||
}
|
||||
finally {
|
||||
try {
|
||||
if (outputStream != null) {
|
||||
outputStream.flush();
|
||||
outputStream.close();
|
||||
}
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static FastText readWordVectors(File path) {
|
||||
FastText result = null;
|
||||
try {
|
||||
FileInputStream fileIn = new FileInputStream(path);
|
||||
ObjectInputStream in = new ObjectInputStream(fileIn);
|
||||
try {
|
||||
result = (FastText) in.readObject();
|
||||
} catch (ClassNotFoundException ex) {
|
||||
|
||||
}
|
||||
} catch (FileNotFoundException ex) {
|
||||
ex.printStackTrace();
|
||||
} catch (IOException ex) {
|
||||
ex.printStackTrace();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public static void printOutProjectedMemoryUse(long numWords, int vectorLength, int numTables) {
|
||||
double memSize = numWords * vectorLength * Nd4j.sizeOfDataType() * numTables;
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ import lombok.AllArgsConstructor;
|
|||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
||||
|
@ -207,7 +208,6 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
|||
}
|
||||
|
||||
INDArray mean = words.isMatrix() ? words.mean(0) : words;
|
||||
|
||||
Collection<String> tempRes = wordsNearest(mean, top + positive.size() + negative.size());
|
||||
List<String> realResults = new ArrayList<>();
|
||||
|
||||
|
@ -232,6 +232,22 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
|||
return wordsNearestSum(vec, n);
|
||||
}
|
||||
|
||||
protected INDArray adjustRank(INDArray words) {
|
||||
if (lookupTable instanceof InMemoryLookupTable) {
|
||||
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||
|
||||
INDArray syn0 = l.getSyn0();
|
||||
if (!words.dataType().equals(syn0.dataType())) {
|
||||
return words.castTo(syn0.dataType());
|
||||
}
|
||||
if (words.rank() == 0 || words.rank() > 2) {
|
||||
throw new IllegalStateException("Invalid rank for wordsNearest method");
|
||||
} else if (words.rank() == 1) {
|
||||
return words.reshape(1, -1);
|
||||
}
|
||||
}
|
||||
return words;
|
||||
}
|
||||
/**
|
||||
* Words nearest based on positive and negative words
|
||||
* * @param top the top n words
|
||||
|
@ -239,6 +255,8 @@ public class BasicModelUtils<T extends SequenceElement> implements ModelUtils<T>
|
|||
*/
|
||||
@Override
|
||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||
words = adjustRank(words);
|
||||
|
||||
if (lookupTable instanceof InMemoryLookupTable) {
|
||||
InMemoryLookupTable l = (InMemoryLookupTable) lookupTable;
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.models.embeddings.reader.impl;
|
||||
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
@ -64,6 +65,8 @@ public class FlatModelUtils<T extends SequenceElement> extends BasicModelUtils<T
|
|||
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||
Counter<String> distances = new Counter<>();
|
||||
|
||||
words = adjustRank(words);
|
||||
|
||||
for (String s : vocabCache.words()) {
|
||||
INDArray otherVec = lookupTable.vector(s);
|
||||
double sim = Transforms.cosineSim(Transforms.unitVec(words.dup()), Transforms.unitVec(otherVec.dup()));
|
||||
|
|
|
@ -103,6 +103,7 @@ public class TreeModelUtils<T extends SequenceElement> extends BasicModelUtils<T
|
|||
@Override
|
||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||
checkTree();
|
||||
words = adjustRank(words);
|
||||
|
||||
List<DataPoint> add = new ArrayList<>();
|
||||
List<Double> distances = new ArrayList<>();
|
||||
|
|
|
@ -172,4 +172,10 @@ public interface WordVectors extends Serializable, EmbeddingInitializer {
|
|||
*/
|
||||
void setModelUtils(ModelUtils utils);
|
||||
|
||||
/**
|
||||
* Does implementation vectorize words absent in vocabulary
|
||||
* @return boolean
|
||||
*/
|
||||
boolean outOfVocabularySupported();
|
||||
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import com.google.common.util.concurrent.AtomicDouble;
|
|||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import lombok.val;
|
||||
import org.apache.commons.lang.ArrayUtils;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
|
@ -357,4 +358,9 @@ public class WordVectorsImpl<T extends SequenceElement> implements WordVectors {
|
|||
public boolean jsonSerializable() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean outOfVocabularySupported() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,10 +6,13 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.deeplearning4j.models.embeddings.WeightLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
|
||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
|
||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
|
@ -17,41 +20,78 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
@AllArgsConstructor
|
||||
@lombok.Builder
|
||||
public class FastText implements WordVectors {
|
||||
public class FastText implements WordVectors, Serializable {
|
||||
|
||||
private boolean supervised;
|
||||
private boolean quantize;
|
||||
private boolean predict;
|
||||
private boolean predict_prob;
|
||||
// Mandatory
|
||||
@Getter private String inputFile;
|
||||
@Getter private String outputFile;
|
||||
|
||||
private boolean skipgram;
|
||||
@Builder.Default private int bucket = 100;
|
||||
@Builder.Default private int minCount = 1;
|
||||
// Optional for dictionary
|
||||
@Builder.Default private int bucket = -1;
|
||||
@Builder.Default private int minCount = -1;
|
||||
@Builder.Default private int minCountLabel = -1;
|
||||
@Builder.Default private int wordNgrams = -1;
|
||||
@Builder.Default private int minNgramLength = -1;
|
||||
@Builder.Default private int maxNgramLength = -1;
|
||||
@Builder.Default private int samplingThreshold = -1;
|
||||
private String labelPrefix;
|
||||
|
||||
private boolean cbow;
|
||||
private boolean nn;
|
||||
private boolean analogies;
|
||||
private String inputFile;
|
||||
private String outputFile;
|
||||
private SentenceIterator iterator;
|
||||
private String modelName;
|
||||
private String lossName;
|
||||
//TODO:
|
||||
private double[] pretrainedVectors;
|
||||
// Optional for training
|
||||
@Getter private boolean supervised;
|
||||
@Getter private boolean quantize;
|
||||
@Getter private boolean predict;
|
||||
@Getter private boolean predict_prob;
|
||||
@Getter private boolean skipgram;
|
||||
@Getter private boolean cbow;
|
||||
@Getter private boolean nn;
|
||||
@Getter private boolean analogies;
|
||||
@Getter private String pretrainedVectorsFile;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private double learningRate = -1.0;
|
||||
@Getter private double learningRateUpdate = -1.0;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int dim = -1;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int contextWindowSize = -1;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int epochs = -1;
|
||||
@Getter private String modelName;
|
||||
@Getter private String lossName;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int negativeSamples = -1;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int numThreads = -1;
|
||||
@Getter private boolean saveOutput = false;
|
||||
|
||||
private JFastText fastTextImpl;
|
||||
private boolean modelLoaded;
|
||||
// Optional for quantization
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int cutOff = -1;
|
||||
@Getter private boolean retrain;
|
||||
@Getter private boolean qnorm;
|
||||
@Getter private boolean qout;
|
||||
@Getter
|
||||
@Builder.Default
|
||||
private int dsub = -1;
|
||||
|
||||
@Getter private SentenceIterator iterator;
|
||||
|
||||
@Builder.Default private transient JFastText fastTextImpl = new JFastText();
|
||||
private transient Word2Vec word2Vec;
|
||||
@Getter private boolean modelLoaded;
|
||||
@Getter private boolean modelVectorsLoaded;
|
||||
private VocabCache vocabCache;
|
||||
|
||||
public FastText(File modelPath) {
|
||||
|
@ -63,8 +103,97 @@ public class FastText implements WordVectors {
|
|||
fastTextImpl = new JFastText();
|
||||
}
|
||||
|
||||
public void init() {
|
||||
fastTextImpl = new JFastText();
|
||||
private static class ArgsFactory {
|
||||
|
||||
private List<String> args = new ArrayList<>();
|
||||
|
||||
private void add(String label, String value) {
|
||||
args.add(label);
|
||||
args.add(value);
|
||||
}
|
||||
|
||||
private void addOptional(String label, int value) {
|
||||
if (value >= 0) {
|
||||
args.add(label);
|
||||
args.add(Integer.toString(value));
|
||||
}
|
||||
}
|
||||
|
||||
private void addOptional(String label, double value) {
|
||||
if (value >= 0.0) {
|
||||
args.add(label);
|
||||
args.add(Double.toString(value));
|
||||
}
|
||||
}
|
||||
|
||||
private void addOptional(String label, String value) {
|
||||
if (StringUtils.isNotEmpty(value)) {
|
||||
args.add(label);
|
||||
args.add(value);
|
||||
}
|
||||
}
|
||||
|
||||
private void addOptional(String label, boolean value) {
|
||||
if (value) {
|
||||
args.add(label);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public String[] args() {
|
||||
String[] asArray = new String[args.size()];
|
||||
return args.toArray(asArray);
|
||||
}
|
||||
}
|
||||
|
||||
private String[] makeArgs() {
|
||||
ArgsFactory argsFactory = new ArgsFactory();
|
||||
|
||||
argsFactory.addOptional("cbow", cbow);
|
||||
argsFactory.addOptional("skipgram", skipgram);
|
||||
argsFactory.addOptional("supervised", supervised);
|
||||
argsFactory.addOptional("quantize", quantize);
|
||||
argsFactory.addOptional("predict", predict);
|
||||
argsFactory.addOptional("predict_prob", predict_prob);
|
||||
|
||||
argsFactory.add("-input", inputFile);
|
||||
argsFactory.add("-output", outputFile );
|
||||
|
||||
argsFactory.addOptional("-pretrainedVectors", pretrainedVectorsFile);
|
||||
|
||||
argsFactory.addOptional("-bucket", bucket);
|
||||
argsFactory.addOptional("-minCount", minCount);
|
||||
argsFactory.addOptional("-minCountLabel", minCountLabel);
|
||||
argsFactory.addOptional("-wordNgrams", wordNgrams);
|
||||
argsFactory.addOptional("-minn", minNgramLength);
|
||||
argsFactory.addOptional("-maxn", maxNgramLength);
|
||||
argsFactory.addOptional("-t", samplingThreshold);
|
||||
argsFactory.addOptional("-label", labelPrefix);
|
||||
argsFactory.addOptional("analogies",analogies);
|
||||
argsFactory.addOptional("-lr", learningRate);
|
||||
argsFactory.addOptional("-lrUpdateRate", learningRateUpdate);
|
||||
argsFactory.addOptional("-dim", dim);
|
||||
argsFactory.addOptional("-ws", contextWindowSize);
|
||||
argsFactory.addOptional("-epoch", epochs);
|
||||
argsFactory.addOptional("-loss", lossName);
|
||||
argsFactory.addOptional("-neg", negativeSamples);
|
||||
argsFactory.addOptional("-thread", numThreads);
|
||||
argsFactory.addOptional("-saveOutput", saveOutput);
|
||||
argsFactory.addOptional("-cutoff", cutOff);
|
||||
argsFactory.addOptional("-retrain", retrain);
|
||||
argsFactory.addOptional("-qnorm", qnorm);
|
||||
argsFactory.addOptional("-qout", qout);
|
||||
argsFactory.addOptional("-dsub", dsub);
|
||||
|
||||
return argsFactory.args();
|
||||
}
|
||||
|
||||
public void fit() {
|
||||
String[] cmd = makeArgs();
|
||||
fastTextImpl.runCmd(cmd);
|
||||
}
|
||||
|
||||
public void loadIterator() {
|
||||
if (iterator != null) {
|
||||
try {
|
||||
File tempFile = File.createTempFile("FTX", ".txt");
|
||||
|
@ -81,24 +210,11 @@ public class FastText implements WordVectors {
|
|||
}
|
||||
}
|
||||
|
||||
public void fit() {
|
||||
|
||||
String[] cmd;
|
||||
if (skipgram) {
|
||||
cmd = new String[]{"skipgram", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
|
||||
"-input", inputFile, "-output", outputFile};
|
||||
}
|
||||
else if (cbow) {
|
||||
cmd = new String[]{"cbow", "-bucket", Integer.toString(bucket), "-minCount", Integer.toString(minCount),
|
||||
"-input", inputFile, "-output", outputFile};
|
||||
}
|
||||
else if (supervised)
|
||||
cmd = new String[]{"supervised", "-input", inputFile,
|
||||
"-output", outputFile};
|
||||
else
|
||||
cmd = new String[]{"-input", inputFile,
|
||||
"-output", outputFile};
|
||||
fastTextImpl.runCmd(cmd);
|
||||
public void loadPretrainedVectors(File vectorsFile) {
|
||||
word2Vec = WordVectorSerializer.readWord2VecModel(vectorsFile);
|
||||
modelVectorsLoaded = true;
|
||||
log.info("Loaded vectorized representation from file %s. Functionality will be restricted.",
|
||||
vectorsFile.getAbsolutePath());
|
||||
}
|
||||
|
||||
public void loadBinaryModel(String modelPath) {
|
||||
|
@ -111,10 +227,18 @@ public class FastText implements WordVectors {
|
|||
modelLoaded = false;
|
||||
}
|
||||
|
||||
public void test(File testFile) {
|
||||
fastTextImpl.test(testFile.getAbsolutePath());
|
||||
}
|
||||
|
||||
private void assertModelLoaded() {
|
||||
if (!modelLoaded && !modelVectorsLoaded)
|
||||
throw new IllegalStateException("Model must be loaded before predict!");
|
||||
}
|
||||
|
||||
public String predict(String text) {
|
||||
|
||||
if (!modelLoaded)
|
||||
throw new IllegalStateException("Model must be loaded before predict!");
|
||||
assertModelLoaded();
|
||||
|
||||
String label = fastTextImpl.predict(text);
|
||||
return label;
|
||||
|
@ -122,8 +246,7 @@ public class FastText implements WordVectors {
|
|||
|
||||
public Pair<String, Float> predictProbability(String text) {
|
||||
|
||||
if (!modelLoaded)
|
||||
throw new IllegalStateException("Model must be loaded before predict!");
|
||||
assertModelLoaded();
|
||||
|
||||
JFastText.ProbLabel predictedProbLabel = fastTextImpl.predictProba(text);
|
||||
|
||||
|
@ -135,6 +258,10 @@ public class FastText implements WordVectors {
|
|||
|
||||
@Override
|
||||
public VocabCache vocab() {
|
||||
if (modelVectorsLoaded) {
|
||||
vocabCache = word2Vec.vocab();
|
||||
}
|
||||
else {
|
||||
if (!modelLoaded)
|
||||
throw new IllegalStateException("Load model before calling vocab()");
|
||||
|
||||
|
@ -148,14 +275,22 @@ public class FastText implements WordVectors {
|
|||
word.setWord(words.get(i));
|
||||
vocabCache.addToken(word);
|
||||
}
|
||||
}
|
||||
return vocabCache;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long vocabSize() {
|
||||
long result = 0;
|
||||
if (modelVectorsLoaded) {
|
||||
result = word2Vec.vocabSize();
|
||||
}
|
||||
else {
|
||||
if (!modelLoaded)
|
||||
throw new IllegalStateException("Load model before calling vocab()");
|
||||
return fastTextImpl.getNWords();
|
||||
result = fastTextImpl.getNWords();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -170,6 +305,10 @@ public class FastText implements WordVectors {
|
|||
|
||||
@Override
|
||||
public double[] getWordVector(String word) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.getWordVector(word);
|
||||
}
|
||||
else {
|
||||
List<Float> vectors = fastTextImpl.getVector(word);
|
||||
double[] retVal = new double[vectors.size()];
|
||||
for (int i = 0; i < vectors.size(); ++i) {
|
||||
|
@ -177,92 +316,149 @@ public class FastText implements WordVectors {
|
|||
}
|
||||
return retVal;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getWordVectorMatrixNormalized(String word) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.getWordVectorMatrixNormalized(word);
|
||||
}
|
||||
else {
|
||||
INDArray r = getWordVectorMatrix(word);
|
||||
return r.divi(Nd4j.getBlasWrapper().nrm2(r));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getWordVectorMatrix(String word) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.getWordVectorMatrix(word);
|
||||
}
|
||||
else {
|
||||
double[] values = getWordVector(word);
|
||||
return Nd4j.createFromArray(values);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getWordVectors(Collection<String> labels) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.getWordVectors(labels);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public INDArray getWordVectorsMean(Collection<String> labels) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.getWordVectorsMean(labels);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private List<String> words = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public boolean hasWord(String word) {
|
||||
return fastTextImpl.getWords().contains(word);
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.outOfVocabularySupported();
|
||||
}
|
||||
if (words.isEmpty())
|
||||
words = fastTextImpl.getWords();
|
||||
return words.contains(word);
|
||||
}
|
||||
|
||||
protected transient ModelUtils modelUtils = new BasicModelUtils<>();
|
||||
protected transient ModelUtils modelUtils;
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearest(INDArray words, int top) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearest(words, top);
|
||||
}
|
||||
return modelUtils.wordsNearest(words, top);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearestSum(INDArray words, int top) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearestSum(words, top);
|
||||
}
|
||||
return modelUtils.wordsNearestSum(words, top);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearestSum(String word, int n) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearestSum(word, n);
|
||||
}
|
||||
return modelUtils.wordsNearestSum(word, n);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearestSum(positive, negative, top);
|
||||
}
|
||||
return modelUtils.wordsNearestSum(positive, negative, top);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Double> accuracy(List<String> questions) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.accuracy(questions);
|
||||
}
|
||||
return modelUtils.accuracy(questions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int indexOf(String word) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.indexOf(word);
|
||||
}
|
||||
return vocab().indexOf(word);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<String> similarWordsInVocabTo(String word, double accuracy) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.similarWordsInVocabTo(word, accuracy);
|
||||
}
|
||||
return modelUtils.similarWordsInVocabTo(word, accuracy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearest(positive, negative, top);
|
||||
}
|
||||
return modelUtils.wordsNearest(positive, negative, top);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Collection<String> wordsNearest(String word, int n) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.wordsNearest(word,n);
|
||||
}
|
||||
return modelUtils.wordsNearestSum(word, n);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public double similarity(String word, String word2) {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.similarity(word, word2);
|
||||
}
|
||||
return modelUtils.similarity(word, word2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public WeightLookupTable lookupTable() {
|
||||
if (modelVectorsLoaded) {
|
||||
return word2Vec.lookupTable();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
@ -320,4 +516,9 @@ public class FastText implements WordVectors {
|
|||
return fastTextImpl.getLabelPrefix();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean outOfVocabularySupported() {
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -376,6 +376,11 @@ public class StaticWord2Vec implements WordVectors {
|
|||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean outOfVocabularySupported() {
|
||||
return false;
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
|
||||
private AbstractStorage<Integer> storage;
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package org.deeplearning4j.models.fasttext;
|
||||
|
||||
import com.github.jfasttext.JFastText;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||
import org.deeplearning4j.models.word2vec.Word2Vec;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
|
@ -13,6 +17,7 @@ import org.nd4j.resources.Resources;
|
|||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
|
||||
|
||||
import static org.junit.Assert.assertArrayEquals;
|
||||
|
@ -23,7 +28,9 @@ import static org.junit.Assert.assertEquals;
|
|||
public class FastTextTest extends BaseDL4JTest {
|
||||
|
||||
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
|
||||
private File modelFile = Resources.asFile("models/fasttext/supervised.model.bin");
|
||||
private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
|
||||
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
|
||||
private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec");
|
||||
|
||||
|
||||
@Rule
|
||||
|
@ -39,7 +46,6 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
inputFile(inputFile.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.init();
|
||||
fastText.fit();
|
||||
}
|
||||
|
||||
|
@ -53,7 +59,6 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
inputFile(inputFile.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.init();
|
||||
fastText.fit();
|
||||
}
|
||||
|
||||
|
@ -68,7 +73,6 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
inputFile(inputFile.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.init();
|
||||
fastText.fit();
|
||||
}
|
||||
|
||||
|
@ -82,34 +86,42 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
inputFile(inputFile.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.init();
|
||||
fastText.fit();
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void tesLoadCBOWModel() throws IOException {
|
||||
|
||||
FastText fastText = new FastText(cbowModelFile);
|
||||
fastText.test(cbowModelFile);
|
||||
|
||||
assertEquals(19, fastText.vocab().numWords());
|
||||
assertEquals("enjoy", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||
|
||||
double[] expected = {5.040466203354299E-4, 0.001005030469968915, 2.8882650076411664E-4, -6.413314840756357E-4, -1.78931062691845E-4, -0.0023157168179750443, -0.002215880434960127, 0.00274421414360404, -1.5344757412094623E-4, 4.6274057240225375E-4, -1.4383681991603225E-4, 3.7832374800927937E-4, 2.523412986192852E-4, 0.0018913350068032742, -0.0024741862434893847, -4.976555937901139E-4, 0.0039220210164785385, -0.001781729981303215, -6.010578363202512E-4, -0.00244093406945467, -7.98621098510921E-4, -0.0010007203090935946, -0.001640203408896923, 7.897148607298732E-4, 9.131592814810574E-4, -0.0013367272913455963, -0.0014030139427632093, -7.755287806503475E-4, -4.2878396925516427E-4, 6.912827957421541E-4, -0.0011824817629531026, -0.0036014916840940714, 0.004353308118879795, -7.073904271237552E-5, -9.646290563978255E-4, -0.0031849315855652094, 2.3360115301329643E-4, -2.9103990527801216E-4, -0.0022990566212683916, -0.002393763978034258, -0.001034979010000825, -0.0010725988540798426, 0.0018285386031493545, -0.0013178540393710136, -1.6632364713586867E-4, -1.4665909475297667E-5, 5.445032729767263E-4, 2.999933494720608E-4, -0.0014367225812748075, -0.002345481887459755, 0.001117417006753385, -8.688368834555149E-4, -0.001830018823966384, 0.0013242220738902688, -8.880519890226424E-4, -6.888324278406799E-4, -0.0036394784692674875, 0.002179111586883664, -1.7201311129610986E-4, 0.002365073887631297, 0.002688770182430744, 0.0023955567739903927, 0.001469283364713192, 0.0011803617235273123, 5.871498142369092E-4, -7.099180947989225E-4, 7.518937345594168E-4, -8.599072461947799E-4, -6.600041524507105E-4, -0.002724145073443651, -8.365285466425121E-4, 0.0013173354091122746, 0.001083166105672717, 0.0014539906987920403, -3.1698777456767857E-4, -2.387022686889395E-4, 1.9560157670639455E-4, 0.0020277926232665777, -0.0012741144746541977, -0.0013026101514697075, -1.5212174912448972E-4, 0.0014194383984431624, 0.0012500399025157094, 0.0013362085446715355, 3.692879108712077E-4, 4.319801155361347E-5, 0.0011261265026405454, 0.0017244465416297317, 5.564604725805111E-5, 0.002170475199818611, 0.0014707016525790095, 0.001303741242736578, 0.005553730763494968, -0.0011097051901742816, -0.0013661726843565702, 0.0014100460102781653, 0.0011811562580987811, -6.622733199037611E-4, 7.860265322960913E-4, -9.811905911192298E-4};
|
||||
assertArrayEquals(expected, fastText.getWordVector("enjoy"), 1e-4);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPredict() throws IOException {
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
String text = "I like soccer";
|
||||
|
||||
FastText fastText = new FastText(modelFile);
|
||||
FastText fastText = new FastText(supModelFile);
|
||||
assertEquals(48, fastText.vocab().numWords());
|
||||
assertEquals("association", fastText.vocab().wordAtIndex(fastText.vocab().numWords() - 1));
|
||||
|
||||
double[] expected = {-0.006423053797334433, 0.007660661358386278, 0.006068876478821039, -0.004772625397890806, -0.007143457420170307, -0.007735592778772116, -0.005607823841273785, -0.00836215727031231, 0.0011235733982175589, 2.599214785732329E-4, 0.004131870809942484, 0.007203693501651287, 0.0016768622444942594, 0.008694255724549294, -0.0012487826170399785, -0.00393667770549655, -0.006292815785855055, 0.0049359360709786415, -3.356488887220621E-4, -0.009407570585608482, -0.0026168026961386204, -0.00978928804397583, 0.0032913016621023417, -0.0029464277904480696, -0.008649969473481178, 8.056449587456882E-4, 0.0043088337406516075, -0.008980576880276203, 0.008716211654245853, 0.0073893265798687935, -0.007388216909021139, 0.003814412746578455, -0.005518500227481127, 0.004668557550758123, 0.006603693123906851, 0.003820829326286912, 0.007174000144004822, -0.006393063813447952, -0.0019381389720365405, -0.0046371882781386375, -0.006193376146256924, -0.0036685809027403593, 7.58899434003979E-4, -0.003185075242072344, -0.008330358192324638, 3.3206873922608793E-4, -0.005389622412621975, 0.009706716984510422, 0.0037855932023376226, -0.008665262721478939, -0.0032511046156287193, 4.4134497875347733E-4, -0.008377416990697384, -0.009110655635595322, 0.0019723298028111458, 0.007486093323677778, 0.006400121841579676, 0.00902814231812954, 0.00975200068205595, 0.0060582347214221954, -0.0075621469877660275, 1.0270809434587136E-4, -0.00673140911385417, -0.007316927425563335, 0.009916870854794979, -0.0011407854035496712, -4.502215306274593E-4, -0.007612560410052538, 0.008726916275918484, -3.0280642022262327E-5, 0.005529289599508047, -0.007944817654788494, 0.005593308713287115, 0.003423960180953145, 4.1348213562741876E-4, 0.009524818509817123, -0.0025129399728029966, -0.0030074280221015215, -0.007503866218030453, -0.0028124507516622543, -0.006841592025011778, -2.9375351732596755E-4, 0.007195258513092995, -0.007775942329317331, 3.951996040996164E-4, -0.006887971889227629, 0.0032655203249305487, -0.007975360378623009, -4.840183464693837E-6, 0.004651934839785099, 0.0031739831902086735, 0.004644941072911024, -0.007461248897016048, 0.003057275665923953, 0.008903342299163342, 0.006857945583760738, 0.007567950990051031, 0.001506582135334611, 0.0063307867385447025, 0.005645462777465582};
|
||||
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-5);
|
||||
assertArrayEquals(expected, fastText.getWordVector("association"), 1e-4);
|
||||
|
||||
String label = fastText.predict(text);
|
||||
assertEquals("__label__soccer", label);
|
||||
}
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testPredictProbability() throws IOException {
|
||||
String text = "I like soccer";
|
||||
|
||||
FastText fastText = new FastText(modelFile);
|
||||
FastText fastText = new FastText(supModelFile);
|
||||
|
||||
Pair<String,Float> result = fastText.predictProbability(text);
|
||||
assertEquals("__label__soccer", result.getFirst());
|
||||
|
@ -129,7 +141,7 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testVocabulary() throws IOException {
|
||||
FastText fastText = new FastText(modelFile);
|
||||
FastText fastText = new FastText(supModelFile);
|
||||
assertEquals(48, fastText.vocab().numWords());
|
||||
assertEquals(48, fastText.vocabSize());
|
||||
|
||||
|
@ -149,7 +161,7 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
FastText fastText =
|
||||
FastText.builder().supervised(true).iterator(iter).build();
|
||||
fastText.init();
|
||||
fastText.loadIterator();
|
||||
|
||||
} catch (IOException e) {
|
||||
log.error(e.toString());
|
||||
|
@ -162,4 +174,60 @@ public class FastTextTest extends BaseDL4JTest {
|
|||
String label = fastText.predict("something");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPretrainedVectors() throws IOException {
|
||||
File output = testDir.newFile();
|
||||
|
||||
FastText fastText =
|
||||
FastText.builder().supervised(true).
|
||||
inputFile(inputFile.getAbsolutePath()).
|
||||
pretrainedVectorsFile(supervisedVectors.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.fit();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testWordsStatistics() throws IOException {
|
||||
|
||||
File output = testDir.newFile();
|
||||
|
||||
FastText fastText =
|
||||
FastText.builder().supervised(true).
|
||||
inputFile(inputFile.getAbsolutePath()).
|
||||
outputFile(output.getAbsolutePath()).build();
|
||||
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
fastText.fit();
|
||||
|
||||
Word2Vec word2Vec = WordVectorSerializer.readAsCsv(new File(output.getAbsolutePath() + ".vec"));
|
||||
|
||||
assertEquals(48, word2Vec.getVocab().numWords());
|
||||
|
||||
System.out.println(word2Vec.wordsNearest("association", 3));
|
||||
System.out.println(word2Vec.similarity("Football", "teams"));
|
||||
System.out.println(word2Vec.similarity("professional", "minutes"));
|
||||
System.out.println(word2Vec.similarity("java","cpp"));
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testWordsNativeStatistics() throws IOException {
|
||||
|
||||
File output = testDir.newFile();
|
||||
|
||||
FastText fastText = new FastText();
|
||||
fastText.loadPretrainedVectors(supervisedVectors);
|
||||
|
||||
log.info("\nTraining supervised model ...\n");
|
||||
|
||||
assertEquals(48, fastText.vocab().numWords());
|
||||
|
||||
String[] result = new String[3];
|
||||
fastText.wordsNearest("association", 3).toArray(result);
|
||||
assertArrayEquals(new String[]{"most","eleven","hours"}, result);
|
||||
assertEquals(0.1657, fastText.similarity("Football", "teams"), 1e-4);
|
||||
assertEquals(0.3661, fastText.similarity("professional", "minutes"), 1e-4);
|
||||
assertEquals(Double.NaN, fastText.similarity("java","cpp"), 1e-4);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
|
|||
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
|
||||
import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;
|
||||
import org.deeplearning4j.models.embeddings.reader.impl.FlatModelUtils;
|
||||
import org.deeplearning4j.models.fasttext.FastText;
|
||||
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
|
||||
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
|
||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||
|
@ -42,6 +43,7 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import java.io.ByteArrayInputStream;
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
@ -289,4 +291,33 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void FastText_Correct_WhenDeserialized() throws IOException {
|
||||
|
||||
FastText fastText =
|
||||
FastText.builder().cbow(true).build();
|
||||
|
||||
WordVectorSerializer.writeWordVectors(fastText, new File("some.data"));
|
||||
|
||||
FastText deser = null;
|
||||
try {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
deser = WordVectorSerializer.readWordVectors(new File("some.data"));
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
fail();
|
||||
}
|
||||
|
||||
assertNotNull(deser);
|
||||
assertEquals(fastText.isCbow(), deser.isCbow());
|
||||
assertEquals(fastText.isModelLoaded(), deser.isModelLoaded());
|
||||
assertEquals(fastText.isAnalogies(), deser.isAnalogies());
|
||||
assertEquals(fastText.isNn(), deser.isNn());
|
||||
assertEquals(fastText.isPredict(), deser.isPredict());
|
||||
assertEquals(fastText.isPredict_prob(), deser.isPredict_prob());
|
||||
assertEquals(fastText.isQuantize(), deser.isQuantize());
|
||||
assertEquals(fastText.getInputFile(), deser.getInputFile());
|
||||
assertEquals(fastText.getOutputFile(), deser.getOutputFile());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -453,6 +453,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
|
||||
DataType netDtype = getConfiguration().getDataType();
|
||||
if(parameters != null && parameters.dataType() != netDtype){
|
||||
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
|
||||
if(cloneParametersArray){
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
parameters = parameters.castTo(netDtype);
|
||||
|
|
|
@ -178,8 +178,7 @@ public class EmbeddingSequenceLayer extends BaseLayer<org.deeplearning4j.nn.conf
|
|||
", mask shape: " + Arrays.toString(maskArray.shape()));
|
||||
}
|
||||
//Returned array: rank 3, shape [mb, vector, seqLength]. mask shape: [mb, seqLength]
|
||||
Broadcast.mul(ret, maskArray, ret, 0, 2);
|
||||
// ret.muliColumnVector(maskArray);
|
||||
Broadcast.mul(ret, maskArray.castTo(ret.dataType()), ret, 0, 2);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -616,6 +616,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
DataType netDtype = getLayerWiseConfigurations().getDataType();
|
||||
if(parameters != null && parameters.dataType() != netDtype){
|
||||
Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters);
|
||||
if(cloneParametersArray){
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
parameters = parameters.castTo(netDtype);
|
||||
|
@ -627,6 +628,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
if (layerMap == null)
|
||||
layerMap = new LinkedHashMap<>();
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ import java.util.Arrays;
|
|||
@Ignore
|
||||
public class TestSameDiffUI {
|
||||
|
||||
// @Ignore
|
||||
@Ignore
|
||||
@Test
|
||||
public void testSameDiff() throws Exception {
|
||||
|
||||
|
|
|
@ -1598,9 +1598,6 @@ namespace nd4j {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
int NDArray::rankOf() const {
|
||||
if (isEmpty())
|
||||
return 0;
|
||||
|
||||
return shape::rank(_shapeInfo);
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ std::string NDArray::e(const Nd4jLong i) const;
|
|||
template <typename T>
|
||||
NDArray* NDArray::asT() const{
|
||||
|
||||
auto result = new NDArray(ordering(), isScalar() ? std::vector<Nd4jLong>({0}) : getShapeAsVector(), DataTypeUtils::fromT<T>());
|
||||
auto result = isScalar() ? new NDArray('c', {}, {0.}, DataTypeUtils::fromT<T>(), this->getContext()) : new NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
||||
auto l = this->lengthOf();
|
||||
|
||||
prepareSpecialUse({result}, {this});
|
||||
|
@ -67,17 +67,18 @@ NDArray::NDArray(const NDArray& other) {
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||
|
||||
if (shape.empty())
|
||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
||||
|
||||
if ((int) shape.size() > MAX_RANK)
|
||||
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
||||
|
||||
_context = context;
|
||||
_isAttached = getContext()->getWorkspace() != nullptr;
|
||||
_isAttached = _context->getWorkspace() != nullptr;
|
||||
_offset = 0;
|
||||
|
||||
if (shape.empty())
|
||||
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
|
||||
else
|
||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||
|
||||
_buffer = std::make_shared<DataBuffer>(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace());
|
||||
_buffer->setToZeroBuffers();
|
||||
}
|
||||
|
@ -85,16 +86,20 @@ NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, nd4j::Dat
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray::NDArray(const char order, const std::vector<Nd4jLong> &shape, const std::vector<double>& data, nd4j::DataType dtype, nd4j::LaunchContext * context) {
|
||||
|
||||
if (shape.empty())
|
||||
throw std::runtime_error("NDArray constructor: input shape is empty !");
|
||||
|
||||
if ((int) shape.size() > MAX_RANK)
|
||||
throw std::invalid_argument("Rank of NDArray can't exceed 32");
|
||||
|
||||
_context = context;
|
||||
_offset = 0;
|
||||
|
||||
if (shape.size() == 0) {
|
||||
if (data.size() == 0)
|
||||
setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype));
|
||||
else
|
||||
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
|
||||
} else {
|
||||
setShapeInfo(ShapeDescriptor(dtype, order, shape));
|
||||
}
|
||||
|
||||
if (lengthOf() != data.size()) {
|
||||
nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf());
|
||||
|
@ -2441,6 +2446,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
|
|||
if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other->isB()) || (op.s == scalar::ReverseDivide && this->isB()))
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
|
||||
|
||||
if (isEmpty() || other->isEmpty())
|
||||
return;
|
||||
|
||||
NDArray::prepareSpecialUse({target}, {this, other});
|
||||
|
||||
if (isScalar()) {
|
||||
|
@ -2513,6 +2521,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
if(target == nullptr || other == nullptr)
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: target or other = nullptr !");
|
||||
|
||||
if (isEmpty() || other->isEmpty())
|
||||
return;
|
||||
|
||||
NDArray::prepareSpecialUse({target}, {this, other});
|
||||
|
||||
if (isScalar()) {
|
||||
|
@ -2583,6 +2594,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
|
||||
if (isEmpty() || other.isEmpty()) {
|
||||
if (isEmpty())
|
||||
return NDArray(*this);
|
||||
else
|
||||
return NDArray(other);
|
||||
}
|
||||
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)()
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
||||
|
@ -2812,6 +2830,19 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
|||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||
return true;
|
||||
|
||||
const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end();
|
||||
|
||||
if(isEmpty() && !isOutShapeEmpty)
|
||||
throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !");
|
||||
if(!isEmpty() && isOutShapeEmpty)
|
||||
throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !");
|
||||
if(isEmpty() && isOutShapeEmpty) {
|
||||
Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace());
|
||||
setShapeInfo(shapeInfoNew);
|
||||
RELEASE(shapeInfoNew, getContext()->getWorkspace());
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shape(cshape);
|
||||
int rank = shape.size();
|
||||
|
||||
|
@ -2823,7 +2854,7 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
|||
for (int i = 0; i < (int) shape.size(); i++) {
|
||||
if (shape[i] < 0) {
|
||||
if (numberNegativesOnes >= 1)
|
||||
throw std::runtime_error("Only one dimension can be negative at once");
|
||||
throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
|
||||
|
||||
numberNegativesOnes++;
|
||||
|
||||
|
@ -3664,7 +3695,7 @@ void NDArray::reduceAlongDimension(nd4j::reduce::SameOps op, NDArray* target, co
|
|||
if(rankOf() == copy.size() || copy.empty()) {
|
||||
NativeOpExecutioner::execReduceSameScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo());
|
||||
}
|
||||
else {
|
||||
else { //if (!isEmpty()) {
|
||||
auto pDims = nd4j::Environment::getInstance()->isCPU() ? copy.data() : nullptr;
|
||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
|
||||
NativeOpExecutioner::execReduceSame(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), nullptr, target->getBuffer(), target->getShapeInfo(), target->getSpecialBuffer(), target->getSpecialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
|
||||
|
@ -4198,6 +4229,9 @@ NDArray* NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& d
|
|||
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
|
||||
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
|
||||
|
||||
if(isEmpty())
|
||||
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");
|
||||
|
||||
const int rank = rankOf();
|
||||
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
|
||||
|
||||
|
@ -4260,6 +4294,9 @@ NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& di
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const {
|
||||
|
||||
if(isEmpty())
|
||||
throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !");
|
||||
|
||||
const int rank = rankOf();
|
||||
const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? rank : rank - dimsToExclude.size();
|
||||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude);
|
||||
|
|
|
@ -1334,18 +1334,7 @@ public:
|
|||
* @param npyArray
|
||||
* @return
|
||||
*/
|
||||
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||
unsigned int shapeSize = arr.shape.size();
|
||||
auto shape = new unsigned int[shapeSize];
|
||||
for(unsigned int i = 0; i < shapeSize; i++) {
|
||||
shape[i] = arr.shape[i];
|
||||
}
|
||||
|
||||
auto shapeBuffer = shape::shapeBufferOfNpy(arr.shape.size(), shape, arr.fortranOrder);
|
||||
delete[] shape;
|
||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||
}
|
||||
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray);
|
||||
|
||||
|
||||
/**
|
||||
|
|
|
@ -64,8 +64,8 @@ void NDArray::tickWriteDevice() const { }
|
|||
void NDArray::tickReadHost() const { }
|
||||
void NDArray::tickReadDevice() const { }
|
||||
void NDArray::tickBothActual() const { }
|
||||
bool NDArray::isActualOnHostSide() const { }
|
||||
bool NDArray::isActualOnDeviceSide() const { }
|
||||
bool NDArray::isActualOnHostSide() const { return true; }
|
||||
bool NDArray::isActualOnDeviceSide() const { return true; }
|
||||
void NDArray::makeBothBuffersActual() const { }
|
||||
|
||||
|
||||
|
@ -419,328 +419,8 @@ void NDArray::repeat(int dimension, NDArray& target) const {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
#ifndef __JAVACPP_HACK__
|
||||
|
||||
template<typename T>
|
||||
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
#include "NDArrayLambda.hpp"
|
||||
|
||||
if (second == nullptr) {
|
||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
|
||||
throw std::runtime_error("second is null");
|
||||
}
|
||||
|
||||
if (third == nullptr) {
|
||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
|
||||
throw std::runtime_error("third is null");
|
||||
}
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
||||
|
||||
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = second->bufferAsT<T>();
|
||||
auto t = third->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func(f[e], s[e], t[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second->getOffset(e);
|
||||
auto vOffset = third->getOffset(e);
|
||||
|
||||
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second->getOffset(e);
|
||||
auto vOffset = third->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if (other == nullptr) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||
throw std::runtime_error("Other is null");
|
||||
}
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != other->dataType() || dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
|
||||
|
||||
if (this->lengthOf() != other->lengthOf()) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = other->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++)
|
||||
z[e] = func(f[e], s[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
|
||||
f[xOffset] = func(f[xOffset], s[yOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[xOffset], s[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++)
|
||||
z[e] = func(f[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(f[xOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[xOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func(e, f[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(e, f[xOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(e, f[xOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if (other == nullptr) {
|
||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||
throw std::runtime_error("Other is null");
|
||||
}
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
|
||||
if (this->lengthOf() != other->lengthOf()) {
|
||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = other->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
|
||||
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
|
||||
#endif
|
||||
|
||||
/*
|
||||
|
|
|
@ -0,0 +1,325 @@
|
|||
|
||||
|
||||
|
||||
template<typename T>
|
||||
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if (second == nullptr) {
|
||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
|
||||
throw std::runtime_error("second is null");
|
||||
}
|
||||
|
||||
if (third == nullptr) {
|
||||
nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
|
||||
throw std::runtime_error("third is null");
|
||||
}
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
|
||||
|
||||
if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = second->bufferAsT<T>();
|
||||
auto t = third->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func(f[e], s[e], t[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second->getOffset(e);
|
||||
auto vOffset = third->getOffset(e);
|
||||
|
||||
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto tOffset = this->getOffset(e);
|
||||
auto uOffset = second->getOffset(e);
|
||||
auto vOffset = third->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if (other == nullptr) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||
throw std::runtime_error("Other is null");
|
||||
}
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != other->dataType() || dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
|
||||
|
||||
if (this->lengthOf() != other->lengthOf()) {
|
||||
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = other->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++)
|
||||
z[e] = func(f[e], s[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
|
||||
f[xOffset] = func(f[xOffset], s[yOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[xOffset], s[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++)
|
||||
z[e] = func(f[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(f[xOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(f[xOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func(e, f[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
|
||||
f[xOffset] = func(e, f[xOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func(e, f[xOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
|
||||
if (target == nullptr)
|
||||
target = this;
|
||||
|
||||
if (other == nullptr) {
|
||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
|
||||
throw std::runtime_error("Other is null");
|
||||
}
|
||||
if(dataType() != DataTypeUtils::fromT<T>())
|
||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
|
||||
if(dataType() != target->dataType())
|
||||
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
|
||||
if (this->lengthOf() != other->lengthOf()) {
|
||||
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
|
||||
throw std::runtime_error("Shapes mismach");
|
||||
}
|
||||
|
||||
auto f = this->bufferAsT<T>();
|
||||
auto s = other->bufferAsT<T>();
|
||||
auto z = target->bufferAsT<T>();
|
||||
|
||||
if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (Nd4jLong e = 0; e < _length; e++)
|
||||
z[e] = func((Nd4jLong) e, f[e], s[e]);
|
||||
} else {
|
||||
if (f == z) {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
|
||||
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||
}
|
||||
} else {
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int e = 0; e < _length; e++) {
|
||||
|
||||
auto xOffset = this->getOffset(e);
|
||||
auto yOffset = other->getOffset(e);
|
||||
auto zOffset = target->getOffset(e);
|
||||
|
||||
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
|
||||
template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
|
|
@ -2710,6 +2710,32 @@ int NativeOps::dataTypeFromNpyHeader(void *header) {
|
|||
return (int) cnpy::dataTypeFromHeader(reinterpret_cast<char *>(header));
|
||||
}
|
||||
|
||||
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||
unsigned int shapeSize = arr.shape.size();
|
||||
std::vector<Nd4jLong> shape(shapeSize);
|
||||
bool _empty = false;
|
||||
for(unsigned int i = 0; i < shapeSize; i++) {
|
||||
shape[i] = arr.shape[i];
|
||||
|
||||
if (arr.shape[i] == 0)
|
||||
_empty = true;
|
||||
}
|
||||
|
||||
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||
|
||||
Nd4jLong *shapeBuffer;
|
||||
if (_empty) {
|
||||
if (shapeSize > 0)
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
else
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
|
||||
} else {
|
||||
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void flattenGeneric,(Nd4jPointer*, int, char, void*, Nd4jLong*, void*, Nd4jLong*), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pullRowsGeneric, (void *, Nd4jLong*, void*, Nd4jLong*, const int, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void tearGeneric, (void *, Nd4jLong*, Nd4jPointer*, Nd4jLong*, Nd4jLong*, Nd4jLong*), LIBND4J_TYPES);
|
||||
|
|
|
@ -454,7 +454,7 @@ void NDArray::printCurrentBuffer(const bool host, const char* msg, const int pre
|
|||
|
||||
if (ews() != 1) {
|
||||
for (uint i = 0; i < _length; i++)
|
||||
cudaMemcpyAsync(pHost + i * sizeof(T), getSpecialBuffer() + getOffset(i) * sizeof(T), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
|
||||
cudaMemcpyAsync(reinterpret_cast<T*>(pHost) + i, specialBufferWithOffset(i), sizeof(T), cudaMemcpyDeviceToHost, *(getContext()->getCudaStream()));
|
||||
}
|
||||
else
|
||||
cudaMemcpyAsync(pHost, getSpecialBuffer(), sizeOfT() * _length, cudaMemcpyDeviceToHost, *getContext()->getCudaStream());
|
||||
|
@ -475,6 +475,12 @@ template void NDArray::printCurrentBuffer<float>(const bool host, const char* ms
|
|||
template void NDArray::printCurrentBuffer<double>(const bool host, const char* msg, const int precision) const;
|
||||
|
||||
|
||||
#if defined(__CUDACC__) && !defined(BUILD_TESTS)
|
||||
|
||||
#include <cpu/NDArrayLambda.hpp>
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace nd4j
|
||||
#endif
|
||||
|
||||
|
|
|
@ -3105,3 +3105,29 @@ nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, double
|
|||
nd4j::ConstantDataBuffer* NativeOps::constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor) {
|
||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||
}
|
||||
|
||||
Nd4jPointer NativeOps::shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||
unsigned int shapeSize = arr.shape.size();
|
||||
std::vector<Nd4jLong> shape(shapeSize);
|
||||
bool _empty = false;
|
||||
for(unsigned int i = 0; i < shapeSize; i++) {
|
||||
shape[i] = arr.shape[i];
|
||||
|
||||
if (arr.shape[i] == 0)
|
||||
_empty = true;
|
||||
}
|
||||
|
||||
auto dtype = cnpy::dataTypeFromHeader(reinterpret_cast<char *>(npyArray));
|
||||
|
||||
Nd4jLong *shapeBuffer;
|
||||
if (_empty) {
|
||||
if (shapeSize > 0)
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
else
|
||||
shapeBuffer = nd4j::ShapeBuilders::emptyShapeInfo(dtype);
|
||||
} else {
|
||||
shapeBuffer = nd4j::ShapeBuilders::createShapeInfo(dtype, arr.fortranOrder ? 'f' : 'c', shape);
|
||||
}
|
||||
return reinterpret_cast<Nd4jPointer>(shapeBuffer);
|
||||
}
|
||||
|
|
|
@ -53,6 +53,15 @@ namespace nd4j {
|
|||
template <typename T>
|
||||
FORCEINLINE static _CUDA_HD T max();
|
||||
|
||||
/**
|
||||
* returns inf for float/double and max for everything else
|
||||
*/
|
||||
template <typename T>
|
||||
FORCEINLINE static _CUDA_HD T infOrMax();
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE static _CUDA_HD T nanOrZero();
|
||||
|
||||
// returns the difference between 1.0 and the next representable value of the given floating-point type
|
||||
template <typename T>
|
||||
FORCEINLINE static T eps();
|
||||
|
@ -290,6 +299,36 @@ FORCEINLINE _CUDA_HD bfloat16 DataTypeUtils::max<bfloat16>() {
|
|||
return bfloat16::max();
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE _CUDA_HD float DataTypeUtils::infOrMax<float>() {
|
||||
return std::numeric_limits<float>::infinity();
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE _CUDA_HD double DataTypeUtils::infOrMax<double>() {
|
||||
return std::numeric_limits<double>::infinity();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD T DataTypeUtils::infOrMax() {
|
||||
return DataTypeUtils::max<T>();
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE _CUDA_HD float DataTypeUtils::nanOrZero<float>() {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE _CUDA_HD double DataTypeUtils::nanOrZero<double>() {
|
||||
return std::numeric_limits<double>::quiet_NaN();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE _CUDA_HD T DataTypeUtils::nanOrZero() {
|
||||
return static_cast<T>(0);
|
||||
}
|
||||
|
||||
FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||
switch(dataType) {
|
||||
case INT8:
|
||||
|
|
|
@ -55,8 +55,13 @@ bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const {
|
|||
}
|
||||
|
||||
Nd4jLong* ShapeDescriptor::toShapeInfo() const {
|
||||
if (_empty)
|
||||
if (_empty) {
|
||||
if (_rank == 0)
|
||||
return ShapeBuilders::emptyShapeInfo(_dataType);
|
||||
else {
|
||||
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
switch (_rank) {
|
||||
|
@ -133,15 +138,11 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
|
||||
_rank = ((shape.size() == 1 && shape[0] == 0)? 0: shape.size());
|
||||
_rank = shape.size();
|
||||
_ews = 1;
|
||||
|
||||
if (_rank > 0) {
|
||||
_strides.resize(_rank);
|
||||
if (order == 'c')
|
||||
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
|
||||
else
|
||||
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
|
||||
|
||||
for (auto v:_shape) {
|
||||
if (v == 0) {
|
||||
|
@ -149,6 +150,17 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st
|
|||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// no point calculating strides for empty arrays
|
||||
if (!_empty) {
|
||||
if (order == 'c')
|
||||
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
|
||||
else
|
||||
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
|
||||
} else {
|
||||
// all strides set to 0
|
||||
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -191,8 +203,11 @@ ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
|
|||
|
||||
_empty = shape::isEmpty(shapeInfo);
|
||||
|
||||
for (int e = 0; e < _rank; e++)
|
||||
for (int e = 0; e < _rank; e++) {
|
||||
_shape.emplace_back(shapeInfo[e + 1]);
|
||||
if (shapeInfo[e + 1] == 0)
|
||||
_empty = true;
|
||||
}
|
||||
|
||||
for (int e = 0; e < _rank; e++)
|
||||
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
|
||||
|
@ -304,7 +319,14 @@ ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const D
|
|||
ShapeDescriptor descriptor;
|
||||
descriptor._dataType = type;
|
||||
descriptor._shape.emplace_back(length);
|
||||
|
||||
if (length > 0)
|
||||
descriptor._strides.emplace_back(1);
|
||||
else {
|
||||
descriptor._strides.emplace_back(0);
|
||||
descriptor._empty = true;
|
||||
}
|
||||
|
||||
descriptor._order = 'c';
|
||||
descriptor._ews = 1;
|
||||
descriptor._rank = 1;
|
||||
|
|
|
@ -29,7 +29,7 @@
|
|||
#include <array/ArrayOptions.h>
|
||||
|
||||
namespace nd4j {
|
||||
class ShapeBuilders {
|
||||
class ND4J_EXPORT ShapeBuilders {
|
||||
public:
|
||||
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
|
@ -53,6 +53,8 @@ namespace nd4j {
|
|||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -40,6 +40,12 @@ namespace nd4j {
|
|||
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims = false, const bool supportOldShapes = false, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
/**
|
||||
* evaluate output shape for reduce operation when input shape is empty
|
||||
* behavior is analogous to tf
|
||||
*/
|
||||
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
|
||||
|
||||
// evaluate shape for array which is result of repeat operation applied to arr
|
||||
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);
|
||||
|
||||
|
|
|
@ -71,8 +71,10 @@ namespace nd4j {
|
|||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||
|
||||
|
||||
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
|
||||
ConstantDataBuffer offsetsBuffer(oPtr, nullptr, numOfSubArrs*sizeof(Nd4jLong), DataType::INT64);
|
||||
TadPack t(shapesBuffer, offsetsBuffer, numOfSubArrs);
|
||||
|
|
|
@ -75,6 +75,7 @@ namespace nd4j {
|
|||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
|
||||
|
||||
Nd4jPointer soPtr;
|
||||
|
|
|
@ -54,11 +54,6 @@ namespace nd4j {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
Nd4jLong* ShapeBuilders::createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace) {
|
||||
|
||||
if (rank)
|
||||
if(shapeOnly[0] == 0) // scalar case
|
||||
rank = 0;
|
||||
|
||||
Nd4jLong* shapeInfo = nullptr;
|
||||
|
||||
if(rank == 0) { // scalar case
|
||||
|
@ -67,10 +62,23 @@ namespace nd4j {
|
|||
else {
|
||||
ALLOCATE(shapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||
shapeInfo[0] = rank;
|
||||
for(int i = 0; i < rank; ++i)
|
||||
bool isEmpty = false;
|
||||
for(int i = 0; i < rank; ++i) {
|
||||
shapeInfo[i + 1] = shapeOnly[i];
|
||||
|
||||
if (shapeOnly[i] == 0)
|
||||
isEmpty = true;
|
||||
}
|
||||
|
||||
if (!isEmpty) {
|
||||
shape::updateStrides(shapeInfo, order);
|
||||
}
|
||||
else {
|
||||
shapeInfo[shape::shapeInfoLength(rank) - 1] = order;
|
||||
memset(shape::stride(shapeInfo), 0, rank * sizeof(Nd4jLong));
|
||||
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||
}
|
||||
|
||||
nd4j::ArrayOptions::setDataType(shapeInfo, dataType);
|
||||
}
|
||||
|
||||
|
@ -78,9 +86,16 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace) {
|
||||
auto shape = createScalarShapeInfo(dataType, workspace);
|
||||
ArrayOptions::setPropertyBit(shape, ARRAY_EMPTY);
|
||||
return shape;
|
||||
auto shapeInfo = createScalarShapeInfo(dataType, workspace);
|
||||
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||
return shapeInfo;
|
||||
}
|
||||
|
||||
Nd4jLong* ShapeBuilders::emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace) {
|
||||
auto shapeInfo = createShapeInfo(dataType, order, shape, workspace);
|
||||
memset(shape::stride(shapeInfo), 0, shape.size() * sizeof(Nd4jLong));
|
||||
ArrayOptions::setPropertyBit(shapeInfo, ARRAY_EMPTY);
|
||||
return shapeInfo;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -108,27 +108,81 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a, cons
|
|||
return evalShapeForTensorDot(a->getShapeInfo(), b->getShapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
|
||||
}
|
||||
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimensions, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate output shape for reduce operation when input shape is empty
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {
|
||||
|
||||
if (dimsToExclude.size() == 0) { // return copy of input shape
|
||||
Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
|
||||
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||
RELEASE(outShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimensions, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
|
||||
const int rank = shape::rank(shapeInfo);
|
||||
Nd4jLong* outShapeInfo = nullptr;
|
||||
|
||||
if (dimsToExclude.size() == rank) { // return scalar or shape filled with unities
|
||||
|
||||
if(!keepDims)
|
||||
outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
|
||||
else
|
||||
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
|
||||
}
|
||||
else {
|
||||
|
||||
shape::checkDimensions(rank, dimsToExclude);
|
||||
|
||||
std::vector<Nd4jLong> outShape;
|
||||
|
||||
if(keepDims) {
|
||||
outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
|
||||
for(const auto& dim : dimsToExclude)
|
||||
outShape[dim] = 1;
|
||||
}
|
||||
else {
|
||||
for (uint i = 0, j = 0; i < rank; ++i) {
|
||||
if(j < dimsToExclude.size() && i == dimsToExclude[j])
|
||||
++j;
|
||||
else
|
||||
outShape.emplace_back(shapeInfo[i + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
|
||||
}
|
||||
|
||||
ShapeDescriptor descriptor(outShapeInfo, dataType);
|
||||
RELEASE(outShapeInfo, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
|
||||
}
|
||||
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimensions, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
return evalReduceShapeInfo(order, dimsToExclude, arr.getShapeInfo(), dataType, keepDims, supportOldShapes, workspace);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shape resulting from reduce operation
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
|
||||
|
||||
if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
|
||||
return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);
|
||||
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
|
||||
int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));
|
||||
|
||||
if (dimensions.size() == 0) { // return scalar or array with len=1 in this case
|
||||
if (dimsToExclude.size() == 0) { // return scalar or array with len=1 in this case
|
||||
|
||||
if(keepDims && rank > 1) {
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||
|
@ -157,16 +211,16 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
|||
}
|
||||
}
|
||||
|
||||
shape::checkDimensions(rank, dimensions);
|
||||
shape::checkDimensions(rank, dimsToExclude);
|
||||
|
||||
int dimSize = dimensions.size();
|
||||
int dimSize = dimsToExclude.size();
|
||||
|
||||
if(keepDims) {
|
||||
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
|
||||
newShapeInfo[0] = rank;
|
||||
for(int i = 0; i < rank; ++i)
|
||||
if (std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
|
||||
if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
|
||||
newShapeInfo[i+1] = 1;
|
||||
else
|
||||
newShapeInfo[i+1] = shapeInfo[i+1];
|
||||
|
@ -178,7 +232,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
|||
}
|
||||
|
||||
int newRank = rank - dimSize;
|
||||
if (newRank==0 || (dimSize==1 && dimensions[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
|
||||
if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { // check whether given dimension is meant for the whole dimension
|
||||
|
||||
if(supportOldShapes) {
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
|
||||
|
@ -199,7 +253,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
|||
newShapeInfo[0] = newRank; // set rank
|
||||
int j=1;
|
||||
for(int i = 0; i < rank; ++i)
|
||||
if (!std::binary_search(dimensions.begin(), dimensions.end(), i)) // dimensions is already sorted after shape::checkDimensions() has been applied
|
||||
if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i)) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
|
||||
newShapeInfo[j++] = shapeInfo[i+1];
|
||||
|
||||
//ensure whether vector has proper shape for old shape type
|
||||
|
@ -208,7 +262,7 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
|||
RELEASE(newShapeInfo, workspace);
|
||||
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong); // set newRank = 2
|
||||
newShapeInfo[0] = 2;
|
||||
if (dimensions[0] == 0) {
|
||||
if (dimsToExclude[0] == 0) {
|
||||
newShapeInfo[1] = 1;
|
||||
newShapeInfo[2] = oldValue;
|
||||
}
|
||||
|
@ -422,8 +476,23 @@ bool ShapeUtils::evalBroadcastShapeInfo(Nd4jLong *max, Nd4jLong *min, const bool
|
|||
if(maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i])
|
||||
tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];
|
||||
|
||||
// nullify zero axis
|
||||
for (int e = 0; e < maxRank; e++)
|
||||
if (maxShapeInfo[e+1] == 0)
|
||||
tmpShapeInfo[e+1] = 0;
|
||||
|
||||
int delta = maxRank - minRank;
|
||||
for (int e = minRank - 1; e >= 0; e--)
|
||||
if (minShapeInfo[e + 1] == 0)
|
||||
tmpShapeInfo[e + 1 + delta] = 0;
|
||||
|
||||
ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));
|
||||
|
||||
if (shape::isEmpty(max) || shape::isEmpty(min)) {
|
||||
ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
|
||||
memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
|
||||
}
|
||||
|
||||
ShapeDescriptor descriptor(tmpShapeInfo);
|
||||
RELEASE(tmpShapeInfo, workspace);
|
||||
resultShapeInfo = ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
|
@ -805,7 +874,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo,
|
|||
nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
return std::vector<Nd4jLong>({0});
|
||||
return std::vector<Nd4jLong>({});
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -1992,7 +1992,7 @@ template <typename T>
|
|||
len = shape::length(shapeInfo);
|
||||
|
||||
//check whether shape is like {1} or {1,1} or {1,1,1,1,...} - in this case we don't need permute
|
||||
if(len < 2)
|
||||
if(len == 1)
|
||||
return;
|
||||
|
||||
const int rank = shape::rank(shapeInfo);
|
||||
|
@ -3961,7 +3961,7 @@ INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo,
|
|||
newDim = newShape[newStart];
|
||||
oldDim = oldShape[oldStart];
|
||||
|
||||
while (newDim != oldDim)
|
||||
while (newDim != oldDim && newDim > 0 && oldDim > 0)
|
||||
if (newDim < oldDim) newDim *= newShape[newStop++];
|
||||
else oldDim *= oldShape[oldStop++];
|
||||
|
||||
|
|
|
@ -116,13 +116,23 @@ void IndexReduce<X>::exec(void *vx, Nd4jLong *xShapeInfo,
|
|||
auto x = reinterpret_cast<X *>(vx);
|
||||
auto extraParams = reinterpret_cast<X *>(vextraParams);
|
||||
|
||||
const Nd4jLong zLen = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto indexValue = OpType::startingIndexValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(zLen > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < zLen; i++)
|
||||
z[i] = indexValue.index;;
|
||||
return;
|
||||
}
|
||||
|
||||
if(shape::isScalar(zShapeInfo)) {
|
||||
z[0] = execScalar<OpType>(x,xShapeInfo,extraParams);
|
||||
return;
|
||||
}
|
||||
|
||||
const Nd4jLong zLen = shape::length(zShapeInfo);
|
||||
|
||||
auto tadOnlyShapeInfo = tadShapeInfo;
|
||||
Nd4jLong *tadOffsets = tadOffset;
|
||||
|
||||
|
|
|
@ -46,6 +46,21 @@ namespace functions {
|
|||
const Nd4jLong length = shape::length(xShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo)) {
|
||||
z[0] = OpType::startingValue(x);
|
||||
return;
|
||||
}
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < length; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
if (xEws >= 1) {
|
||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||
}
|
||||
|
@ -157,6 +172,16 @@ namespace functions {
|
|||
|
||||
auto resultLength = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < resultLength; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
//the squeezed information doesn't render the right strides for
|
||||
|
|
|
@ -46,6 +46,25 @@ namespace functions {
|
|||
const Nd4jLong length = shape::length(xShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo)) {
|
||||
if (std::is_same<OpType, simdOps::Mean<X,Z>>::value) {
|
||||
z[0] = nd4j::DataTypeUtils::nanOrZero<Z>();
|
||||
} else {
|
||||
z[0] = OpType::startingValue(x);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < length; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
if (xEws > 0) {
|
||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||
}
|
||||
|
@ -165,6 +184,16 @@ namespace functions {
|
|||
|
||||
auto resultLength = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = std::is_same<OpType, simdOps::Mean<X,Z>>::value ? nd4j::DataTypeUtils::nanOrZero<Z>() : static_cast<Z>(OpType::startingValue(x));
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < resultLength; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
//the squeezed information doesn't render the right strides for
|
||||
|
|
|
@ -46,6 +46,21 @@ namespace functions {
|
|||
const Nd4jLong length = shape::length(xShapeInfo);
|
||||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo)) {
|
||||
z[0] = OpType::startingValue(x);
|
||||
return;
|
||||
}
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < length; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
if (xEws >= 1) {
|
||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||
}
|
||||
|
@ -159,6 +174,16 @@ namespace functions {
|
|||
|
||||
auto resultLength = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < resultLength; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
//the squeezed information doesn't render the right strides for
|
||||
|
|
|
@ -48,6 +48,20 @@ namespace functions {
|
|||
const auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
const int rank = shape::rank(xShapeInfo);
|
||||
|
||||
if (shape::isEmpty(xShapeInfo)) {
|
||||
z[0] = OpType::startingValue(x);
|
||||
return;
|
||||
}
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < length; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
if (xEws >= 1) {
|
||||
z[0] = execScalar<OpType>(x, xEws, length, extraParams);
|
||||
|
@ -71,7 +85,7 @@ namespace functions {
|
|||
for (int e = 0; e < maxThreads; e++)
|
||||
start = OpType::update(start, intermediate[e], extraParams);
|
||||
|
||||
z[0] = OpType::postProcess(start, shape::length(xShapeInfo), extraParams);
|
||||
z[0] = OpType::postProcess(start, length, extraParams);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,6 +185,16 @@ namespace functions {
|
|||
|
||||
auto zLength = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(zLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < zLength; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
//the squeezed information doesn't render the right strides for
|
||||
|
|
|
@ -47,6 +47,16 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
|
|||
auto xEws = shape::elementWiseStride(xShapeInfo);
|
||||
auto yEws = shape::elementWiseStride(yShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY || nd4j::ArrayOptions::arrayType(yShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
const auto startingVal = OpType::startingValue(x);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(length > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < length; i++)
|
||||
z[i] = startingVal;
|
||||
return;
|
||||
}
|
||||
|
||||
Z extraParamsVals[3] = {(Z) 0.0f, (Z) 0.0f, (Z) 0.0f};
|
||||
// it's possible case for EqualsWithEps op
|
||||
if (extraParams != nullptr)
|
||||
|
|
|
@ -47,8 +47,8 @@ namespace functions {
|
|||
Nd4jLong *xShapeInfo,
|
||||
void *extraParams,
|
||||
void *z,
|
||||
Nd4jLong *resultShapeInfoBuffer) {
|
||||
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer), SUMMARY_STATS_OPS);
|
||||
Nd4jLong *zShapeInfo) {
|
||||
DISPATCH_BY_OPNUM_TT(execScalar, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo), SUMMARY_STATS_OPS);
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
|
@ -58,10 +58,10 @@ namespace functions {
|
|||
Nd4jLong *xShapeInfo,
|
||||
void *extraParams,
|
||||
void *z,
|
||||
Nd4jLong *resultShapeInfoBuffer,
|
||||
Nd4jLong *zShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, resultShapeInfoBuffer, dimension, dimensionLength), SUMMARY_STATS_OPS);
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(biasCorrected, x, xShapeInfo, extraParams, z, zShapeInfo, dimension, dimensionLength), SUMMARY_STATS_OPS);
|
||||
}
|
||||
|
||||
template <typename X, typename Z>
|
||||
|
@ -71,7 +71,7 @@ namespace functions {
|
|||
Nd4jLong *xShapeInfo,
|
||||
void *vextraParams,
|
||||
void *vz,
|
||||
Nd4jLong *resultShapeInfoBuffer) {
|
||||
Nd4jLong *zShapeInfo) {
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
z[0] = execScalar<OpType>(biasCorrected, vx, xShapeInfo, vextraParams);
|
||||
}
|
||||
|
@ -108,20 +108,31 @@ namespace functions {
|
|||
void *vx,
|
||||
Nd4jLong *xShapeInfo,
|
||||
void *vextraParams,
|
||||
void *vresult,
|
||||
Nd4jLong *resultShapeInfoBuffer,
|
||||
void *vz,
|
||||
Nd4jLong *zShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength) {
|
||||
auto x = reinterpret_cast<X *>(vx);
|
||||
auto z = reinterpret_cast<Z *>(vresult);
|
||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||
|
||||
if (shape::isScalar(resultShapeInfoBuffer)) {
|
||||
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
|
||||
auto x = reinterpret_cast<X *>(vx);
|
||||
auto z = reinterpret_cast<Z *>(vz);
|
||||
auto extraParams = reinterpret_cast<Z *>(vextraParams);
|
||||
int resultLength = shape::length(zShapeInfo);
|
||||
|
||||
if(nd4j::ArrayOptions::arrayType(xShapeInfo) == nd4j::ArrayType::EMPTY) {
|
||||
if(nd4j::ArrayOptions::arrayType(zShapeInfo) == nd4j::ArrayType::EMPTY)
|
||||
return;
|
||||
SummaryStatsData<X> comp;
|
||||
comp.initWithValue(x[0]);
|
||||
PRAGMA_OMP_PARALLEL_FOR_IF(resultLength > nd4j::Environment::getInstance()->elementwiseThreshold())
|
||||
for (uint i = 0; i < resultLength; i++)
|
||||
z[i] = OpType::getValue(biasCorrected, comp);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
if (shape::isScalar(zShapeInfo)) {
|
||||
z[0] = execScalar<OpType>(biasCorrected, x, xShapeInfo, extraParams);
|
||||
return;
|
||||
}
|
||||
|
||||
//no-op
|
||||
if (dimensionLength < 1)
|
||||
|
@ -129,7 +140,6 @@ namespace functions {
|
|||
|
||||
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);
|
||||
|
||||
int resultLength = shape::length(resultShapeInfoBuffer);
|
||||
//pre squeezed: this is for keeping the pointer to the original
|
||||
//shape information for tad offset
|
||||
//the squeezed information doesn't render the right strides for
|
||||
|
|
|
@ -131,7 +131,7 @@ namespace nd4j {
|
|||
COPY_SHAPE(x, shapeE);
|
||||
COPY_SHAPE(y, shapeG);
|
||||
|
||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
||||
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||
|
||||
return shapeList;
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 4) {
|
|||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
|
||||
ConvolutionUtils::conv2d(*block.launchContext(), inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||
ConvolutionUtils::conv2d(block, inputReshaped, weightsReshaped, bias, outputReshaped, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||
|
||||
delete inputReshaped;
|
||||
delete outputReshaped;
|
||||
|
@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 4) {
|
|||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
|
||||
ConvolutionUtils::conv2dBP(*block.launchContext(), inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||
ConvolutionUtils::conv2dBP(block, inputReshaped, weightsReshaped, bias, gradOReshaped, gradIReshaped, gradWReshaped, gradB, 1,kW, 1,sW, 0,pW, 1,1, isSameMode, isNCW);
|
||||
|
||||
delete inputReshaped;
|
||||
delete gradIReshaped;
|
||||
|
|
|
@ -63,7 +63,7 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
|
|||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
|
|||
if(bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::conv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
ConvolutionUtils::conv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -305,7 +305,7 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
|
|||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
|
||||
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -157,7 +157,7 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
|
|||
permutForOutput = {0,2,3,4,1}; // [bS, oC, oD, oH, oW] -> [bS, oD, oH, oW, oC]
|
||||
|
||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||
// [bS, iC, kD, kH, kW, oD, oH, oW] x [kD, kH, kW, iC, oC] = [bS, oD, oH, oW, oC]
|
||||
MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
|
||||
|
||||
|
@ -456,7 +456,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
|||
|
||||
// ----- calculation of gradW and gradB ----- //
|
||||
NDArray columns(input->ordering(), {bS, iC, kD, kH, kW, oD, oH, oW}, input->dataType(), block.launchContext());
|
||||
ConvolutionUtils::vol2col(*block.launchContext(), *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||
ConvolutionUtils::vol2col(block, *input, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||
MmulHelper::tensorDot(&columns, gradO, gradW, {0,5,6,7}, gradOaxesForDot, {3,0,1,2,4}); // [bS, iC, kD, kH, kW, oD, oH, oW] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [iC, kD, kH, kW, oC]
|
||||
|
||||
if(gradB) {
|
||||
|
@ -469,7 +469,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
|||
|
||||
//----- calculation of gradI -----//
|
||||
MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2,3,4,1,0,5,6,7}); // [kD, kH, kW, iC, oC] x [bS, oD, oH, oW, oC]/[bS, oC, oD, oH, oW] = [kD, kH, kW, iC, bS, oD, oH, oW]
|
||||
ConvolutionUtils::col2vol(*block.launchContext(), columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
||||
ConvolutionUtils::col2vol(block, columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
||||
|
||||
if(!isNDHWC) {
|
||||
delete input;
|
||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(deconv2d_tf, 3, 1, false, 0, 9) {
|
|||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
|
||||
|
||||
ConvolutionUtils::conv2dBP(*block.launchContext(), &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
|
|||
// NDHWC: [kD, kH, kW, oC, iC] x [bS, iD, iH, iW, iC] = [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||
// NCDHW: [iC, oC, kD, kH, kW] x [bS, iC, iD, iH, iW] = [oC, kD, kH, kW, bS, iD, iH, iW]
|
||||
nd4j::MmulHelper::tensorDot(weights, input, &columns, {indWiC}, {indIOioC}, {2, 3, 4, 1, 0, 5, 6, 7}); // [bS, oC, kD, kH, kW, iD, iH, iW] -> [kD, kH, kW, oC, bS, iD, iH, iW]
|
||||
ConvolutionUtils::col2vol(*block.launchContext(), columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
||||
ConvolutionUtils::col2vol(block, columns, *output, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, kD, kH, kW, iD, iH, iW] is de-convoluted to [bS, oC, oD, oH, oW]
|
||||
|
||||
//----- add biases if required -----//
|
||||
if(bias)
|
||||
|
@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
|||
|
||||
// ----- calculation of gradW ----- //
|
||||
auto columns = NDArrayFactory::create(input->ordering(), {bS, oC, kD, kH, kW, iD, iH, iW}, input->dataType(), block.launchContext());
|
||||
ConvolutionUtils::vol2col(*block.launchContext(), *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
|
||||
ConvolutionUtils::vol2col(block, *gradO, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); // [bS, oC, oD, oH, oW] is deconvoluted to [bS, oC, kD, kH, kW, iD, iH, iW]
|
||||
MmulHelper::tensorDot(input, &columns, gradW, inputAxesForDot, {0, 5, 6, 7}, {4, 3, 0, 1, 2}); // [bS, iC, iD, iH, iW]/[bS, iD, iH, iW, iC] x [bS, oC, kD, kH, kW, iD, iH, iW] = [iC, oC, kD, kH, kW]
|
||||
|
||||
// ----- calculation of gradB ----- //
|
||||
|
|
|
@ -62,7 +62,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d, 2, 1, false, 0, 9) {
|
|||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::depthwiseConv2d(*block.launchContext(), input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
ConvolutionUtils::depthwiseConv2d(block, input, weights, bias, output, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ CUSTOM_OP_IMPL(depthwise_conv2d_bp, 3, 2, false, 0, 9) {
|
|||
if(bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
ConvolutionUtils::depthwiseConv2dBP(block, input, weights, bias, gradO, gradI, gradW, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
|
|||
if (bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
|
||||
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(avgpool2d, 1, 1, false, 0, 10) {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
||||
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::AVG_POOL, extraParam0);
|
||||
//output->printBuffer("output op");
|
||||
|
||||
if (!isNCHW) {
|
||||
|
@ -198,7 +198,7 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
|||
// *gradI /= kH*kW;
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
|
||||
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 1, extraParam0);
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
|
|
|
@ -69,7 +69,7 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
|
|||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
//T extraParams[] = {};
|
||||
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
|
@ -189,7 +189,7 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
|||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 1, extraParam0);
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool2d, 1, 1, false, 0, 9) {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
||||
ConvolutionUtils::pooling2d(block, *input, *output, kH, kW, sH, sW, pH, pW, dH, dW, PoolingType::MAX_POOL, 1);
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
|
@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
|||
|
||||
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
||||
|
||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
|
||||
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 0., 1.);
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
|
|
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
|||
if(isSameMode) // SAME
|
||||
ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
|
||||
|
||||
ConvolutionUtils::pooling3d(*block.launchContext(), *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||
ConvolutionUtils::pooling3d(block, *input, *output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
|
@ -204,7 +204,7 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
|||
// ConvolutionUtils<T>::col2vol(*columns, *gradI, sD, sH, sW, pD, pH, pW, dD, dH, dW); // columns [bS, iC, kD, kH, kW, oD, oH, oW] is de-convoluted to [bS, iC, iD, iH, iW]
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - unnecessary;
|
||||
ConvolutionUtils::pooling3dBP(*block.launchContext(), *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||
ConvolutionUtils::pooling3dBP(block, *input, *gradO, *gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, 0, 1);
|
||||
|
||||
if(!isNCDHW) {
|
||||
delete input;
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
|||
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, kY, kX, sY, sX, dY, dX);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
||||
ConvolutionUtils::pooling2d(block, *input, *output, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::PNORM_POOL, extraParam0);
|
||||
|
||||
if (!isNCHW) {
|
||||
delete input;
|
||||
|
@ -209,7 +209,7 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
|||
|
||||
// columns->template applyTransform<simdOps::Col2Im<T>>(gradI, std::vector<T>({(T)sH, (T)sW, (T)pH, (T)pW, (T)iH, (T)iW, (T)dH, (T)dW}).data());
|
||||
|
||||
ConvolutionUtils::pooling2dBP(*block.launchContext(), *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
|
||||
ConvolutionUtils::pooling2dBP(block, *input, *gradO, *gradI, kH, kW, sH, sW, pH, pW, dH, dW, 2, pnorm);
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
|
|
|
@ -84,11 +84,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
|
|||
|
||||
if (iC == 1) {
|
||||
nd4j_debug("SCONV2D OP: for input_channels = 1 this op is equivalent to standard conv2d\n","");
|
||||
ConvolutionUtils::conv2d(*block.launchContext(), input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
ConvolutionUtils::conv2d(block, input, weightsDepth, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
ConvolutionUtils::sconv2d(block, input, weightsDepth, weightsPoint, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -274,12 +274,12 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
|||
|
||||
auto resultFFShape = isNCHW ? std::vector<Nd4jLong>({bS, mC*iC, oH, oW}) : std::vector<Nd4jLong>({bS, oH, oW, mC*iC});
|
||||
auto resultFF = NDArrayFactory::create_(input->ordering(), resultFFShape, input->dataType(), block.launchContext());
|
||||
ConvolutionUtils::sconv2d(*block.launchContext(), input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
ConvolutionUtils::sconv2d(block, input, weightsDepth, nullptr, nullptr, resultFF, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
|
||||
auto gradIDepthShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC*mC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
auto gradIDepth = NDArrayFactory::create_(resultFF->ordering(), gradIDepthShape, resultFF->dataType(), block.launchContext()); // [bS, oH, oW, iC*mC] (NHWC) or [bS, iC*mC, oH, oW] (NCHW)
|
||||
|
||||
ConvolutionUtils::conv2dBP(*block.launchContext(), resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
|
||||
ConvolutionUtils::conv2dBP(block, resultFF, weightsPoint, bias, gradO, gradIDepth, gradWP, gradB, 1,1, 1,1, 0,0, 1,1, isSameMode, isNCHW); // in this case oH=iH and oW=iW
|
||||
|
||||
gradO = gradIDepth;
|
||||
bias = gradB = nullptr; // if pointwise backprop was done then don't calculate gradB at depthwise_conv2d_bp step
|
||||
|
@ -288,7 +288,7 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
|||
}
|
||||
|
||||
// ----- apply depthwise_conv2d_bp ----- //
|
||||
ConvolutionUtils::depthwiseConv2dBP(*block.launchContext(), input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
ConvolutionUtils::depthwiseConv2dBP(block, input, weightsDepth, bias, gradO, gradI, gradWD, gradB, kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, isNCHW);
|
||||
|
||||
if(weightsPoint)
|
||||
delete gradO;
|
||||
|
|
|
@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling2d, 1, 1, false, 0, 2) {
|
|||
REQUIRE_TRUE(input->rankOf() == 4, 0, "UPSAMPLING2D op: input should be 4D, but got %i instead!", input->rankOf());
|
||||
REQUIRE_TRUE(output->rankOf() == 4, 0, "UPSAMPLING2D op: output should be 4D, but got %i instead!", output->rankOf());
|
||||
|
||||
ConvolutionUtils::upsampling2d(*block.launchContext(), *input, *output, factorH, factorW, (bool)isNCHW);
|
||||
ConvolutionUtils::upsampling2d(block, *input, *output, factorH, factorW, (bool)isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling2d_bp, 2, 1, false, 0, 0) {
|
|||
REQUIRE_TRUE(gradO->rankOf() == 4, 0, "UPSAMPLING2D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
||||
REQUIRE_TRUE(gradI->rankOf() == 4, 0, "UPSAMPLING2D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
||||
|
||||
ConvolutionUtils::upsampling2dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCHW);
|
||||
ConvolutionUtils::upsampling2dBP(block, *gradO, *gradI, (bool)isNCHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(upsampling3d, 1, 1, false, 0, 3) {
|
|||
REQUIRE_TRUE(input->rankOf() == 5, 0, "UPSAMPLING3D op: input should be 5D, but got %i instead!", input->rankOf());
|
||||
REQUIRE_TRUE(output->rankOf() == 5, 0, "UPSAMPLING3D op: output should be 5D, but got %i instead!", output->rankOf());
|
||||
|
||||
ConvolutionUtils::upsampling3d(*block.launchContext(), *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
|
||||
ConvolutionUtils::upsampling3d(block, *input, *output, factorD, factorH, factorW, (bool)isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -105,7 +105,7 @@ CUSTOM_OP_IMPL(upsampling3d_bp, 2, 1, false, 0, 0) {
|
|||
REQUIRE_TRUE(gradO->rankOf() == 5, 0, "UPSAMPLING3D_BP op: output's gradient array must be 4D, but got %i instead!", gradO->rankOf());
|
||||
REQUIRE_TRUE(gradI->rankOf() == 5, 0, "UPSAMPLING3D_BP op: input's gradient array must be 4D, but got %i instead!", gradI->rankOf());
|
||||
|
||||
ConvolutionUtils::upsampling3dBP(*block.launchContext(), *gradO, *gradI, (bool)isNCDHW);
|
||||
ConvolutionUtils::upsampling3dBP(block, *gradO, *gradI, (bool)isNCDHW);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(batchnorm_new, 3, 1, false, 1, 2) {
|
|||
}
|
||||
|
||||
std::vector<Nd4jLong> shape({2, mean->lengthOf()});
|
||||
NDArray weights = NDArrayFactory::create<float>('c', shape, block.getWorkspace());
|
||||
NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
|
||||
weights({0, 1, 0, 0}).assign(1.0f);
|
||||
weights({1, 2, 0, 0}).assign(0.0f);
|
||||
|
||||
|
|
|
@ -72,6 +72,11 @@ namespace nd4j {
|
|||
if (dims.size() > 1)
|
||||
std::sort(dims.begin(), dims.end());
|
||||
|
||||
|
||||
for (auto d:dims) {
|
||||
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
|
||||
}
|
||||
|
||||
// special case - output is scalar
|
||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));
|
||||
|
|
|
@ -72,6 +72,10 @@ namespace nd4j {
|
|||
if (dims.size() > 1)
|
||||
std::sort(dims.begin(), dims.end());
|
||||
|
||||
for (auto d:dims) {
|
||||
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
|
||||
}
|
||||
|
||||
// special case - output is scalar
|
||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == MAX_INT)) {
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));
|
||||
|
|
|
@ -71,10 +71,8 @@ namespace nd4j {
|
|||
ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(len), Nd4jLong);
|
||||
|
||||
newShape[0] = len;
|
||||
auto empty = false;
|
||||
for (int e = 0; e < shapeArray->lengthOf(); e++){
|
||||
newShape[e+1] = shapeArray->e<Nd4jLong>(e);
|
||||
empty |= (newShape[e+1] == 0); //Support "zeros in shape as empty" for TF import
|
||||
}
|
||||
|
||||
nd4j::DataType dataType;
|
||||
|
@ -90,10 +88,6 @@ namespace nd4j {
|
|||
} else
|
||||
throw std::runtime_error("Fill: missing value to fill output array with");
|
||||
|
||||
if(empty){
|
||||
return SHAPELIST(ShapeBuilders::emptyShapeInfo(dataType, block.getWorkspace()));
|
||||
}
|
||||
|
||||
ShapeUtils::updateStridesAndType(newShape, dataType, 'c');
|
||||
|
||||
return SHAPELIST(CONSTANT(newShape));
|
||||
|
|
|
@ -151,8 +151,10 @@ DECLARE_SHAPE_FN(range) {
|
|||
delta = INPUT_VARIABLE(2)->e<double>(0);
|
||||
}
|
||||
|
||||
if (limit == start)
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
|
||||
if (limit == start){
|
||||
//Return [0] to match TF
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
|
@ -177,8 +179,10 @@ DECLARE_SHAPE_FN(range) {
|
|||
|
||||
//nd4j_printf("Start: [%lld]; Limit: [%lld]; Delta: [%lld];\n", start, limit, delta)
|
||||
|
||||
if (limit == start)
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype));
|
||||
if (limit == start){
|
||||
//Return [0] to match TF
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, dtype));
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
|
@ -203,8 +207,10 @@ DECLARE_SHAPE_FN(range) {
|
|||
delta = INT_ARG(2);
|
||||
}
|
||||
|
||||
if (limit == start)
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(nd4j::DataType::INT32));
|
||||
if (limit == start){
|
||||
//Return [0] to match TF
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::INT32));
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
||||
|
@ -233,9 +239,10 @@ DECLARE_SHAPE_FN(range) {
|
|||
delta = T_ARG(2);
|
||||
}
|
||||
|
||||
//REQUIRE_TRUE(limit != start, 0, "CUSTOM RANGE OP: limit and start values should be different, but got both equal to %f !", limit);
|
||||
if (limit == start)
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->emptyShapeInfo(Environment::getInstance()->defaultFloatDataType()));
|
||||
if (limit == start){
|
||||
//Return [0] to match TF
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, Environment::getInstance()->defaultFloatDataType()));
|
||||
}
|
||||
|
||||
|
||||
REQUIRE_TRUE(delta != 0, 0, "CUSTOM RANGE OP: delta should not be equal to zero !");
|
||||
|
|
|
@ -31,7 +31,8 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||
|
||||
output->assign(static_cast<Nd4jLong>(input->rankOf()));
|
||||
// output->assign(static_cast<Nd4jLong>(input->rankOf()));
|
||||
output->assign(input->rankOf());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -132,17 +132,13 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(size == -1 || size >= 0, 0, "Invalid size[%i] value: must be positive (or -1 for 'all remaining'), got %i", e, size, inShape[e+1]);
|
||||
REQUIRE_TRUE(start >= 0 && start <= inShape[e+1], 0, "Invalid begin[%i] value: Begin must satisfy 0 <= begin <= size[i], got begin=%i for dimension size %i", e, start, inShape[e+1]);
|
||||
REQUIRE_TRUE(start + size <= inShape[e+1], 0, "Slice: interval [%i, %i] is out of bounds for dimension %i with size %i", start, start + size, e, inShape[e+1]);
|
||||
if(start == inShape[e+1] || size == 0 ){
|
||||
empty = true;
|
||||
if(start == inShape[e+1] ){
|
||||
size = 0;
|
||||
}
|
||||
|
||||
shape.emplace_back(size);
|
||||
}
|
||||
|
||||
if(empty){
|
||||
return SHAPELIST(ShapeBuilders::emptyShapeInfo(nd4j::DataType::INT32, block.getWorkspace()));
|
||||
}
|
||||
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
|
|
|
@ -34,6 +34,10 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
|
|||
if(dim < 0)
|
||||
dim += input->rankOf() + 1;
|
||||
|
||||
// no-op in case of empty output array
|
||||
if (output->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
// input validation
|
||||
// check whether shapes of all input array are the same
|
||||
for (int i = 0; i < (int) block.width() - 1; ++i)
|
||||
|
@ -48,16 +52,6 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
|
|||
|
||||
helpers::stack(block.launchContext(), inArrs, output, dim);
|
||||
|
||||
// remove unity from output shape if input arrays are vectors
|
||||
// if(input->isVector()) {
|
||||
// std::vector<int> outShape(output->shapeOf(), output->shapeOf() + output->rankOf());
|
||||
// outShape.erase(find(outShape.begin(), outShape.end(), 1));
|
||||
// output->reshapei(output->ordering(), outShape);
|
||||
// if(dim != 0 && (int)block.width() == 1) // such is implemented by tensorFlow
|
||||
// output->permutei({1, 0});
|
||||
// output->getShapeInfo()[output->rankOf()*2 + 2] = 1;
|
||||
// }
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(pack, stack);
|
||||
|
@ -82,6 +76,20 @@ DECLARE_SHAPE_FN(stack) {
|
|||
|
||||
REQUIRE_TRUE(dim <= inShapeInfo[0], 0, "STACK op: the input dimension parameter must be <= rank of input arrays shapes (rank=%i), but got %i instead !", inShapeInfo[0], dim);
|
||||
|
||||
// empty input arrays require some special handling
|
||||
if (shape::isEmpty(inShapeInfo)) {
|
||||
switch (rank) {
|
||||
case 0: {
|
||||
// we're going to return rank 1 here
|
||||
if (block.width() == 1) {
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, ArrayOptions::dataType(inShapeInfo)));
|
||||
} else {
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), 'c', {(Nd4jLong) block.width(), 0}));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(rank == 0) {
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(block.width(), ArrayOptions::dataType(inShapeInfo)));
|
||||
}
|
||||
|
@ -94,10 +102,6 @@ DECLARE_SHAPE_FN(stack) {
|
|||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape)));
|
||||
}
|
||||
|
||||
// 1) 1х4 + 1х4 = 2х1х4 (along dim=0) = 2x4
|
||||
// 2) 1х4 + 1х4 = 1х2х4 (along dim=1) = 2x4
|
||||
// 3) 4х1 + 4х1 = 2х4x1 (along dim=0) = 2x4
|
||||
// 4) 4х1 + 4х1 = 4х2x1 (along dim=1) = 4x2
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -410,10 +410,13 @@ namespace nd4j {
|
|||
// z->assign(x->e<float>(indices[0]));
|
||||
// }
|
||||
// else {
|
||||
if (indices.size()) {
|
||||
auto sub = (*x)(indices, true, true);
|
||||
z->assign(sub);
|
||||
// }
|
||||
|
||||
}
|
||||
else if (!z->isEmpty()){
|
||||
z->assign(x->e(0));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(stridedslice, strided_slice);
|
||||
|
@ -496,28 +499,19 @@ namespace nd4j {
|
|||
bool is_simple_slice;
|
||||
bool is_dim0;
|
||||
|
||||
// FIXME: remove this, once we bring in 1D NDArrays
|
||||
//vectorize(input_shape);
|
||||
bool result = _preprocess_strided_slice(nullptr, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
|
||||
bool nonEmpty = shape.size() > 0;
|
||||
if (nonEmpty)
|
||||
for (auto x: shape) {
|
||||
if (x == 0) {
|
||||
nonEmpty = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (nonEmpty && inputLen > 1) {
|
||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c', shape);
|
||||
}
|
||||
else {
|
||||
if (shape::rank(inShape) == 0 || begin >= end) {
|
||||
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
|
||||
std::vector<Nd4jLong> indices;
|
||||
bool result = _preprocess_strided_slice(&indices, &shape, input_shape, begin, end, strides, begin_mask, ellipsis_mask, end_mask, new_axis_mask, shrink_axis_mask, &is_identity, &is_simple_slice, &is_dim0);
|
||||
if (indices.size()) {
|
||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
||||
shape);
|
||||
if (inputLen > 1) {
|
||||
newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), 'c',
|
||||
shape);
|
||||
} else {
|
||||
newShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape));
|
||||
|
||||
}
|
||||
}
|
||||
} else
|
||||
newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inShape));
|
||||
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
|
|
|
@ -37,6 +37,9 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
|
||||
|
||||
if(input->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
std::vector<int> dims;
|
||||
for (int e = 0; e < input->rankOf(); e++)
|
||||
if (e != dim)
|
||||
|
@ -76,6 +79,21 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
|
||||
|
||||
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
|
||||
if(shape::shapeOf(inShape)[dim] == 0)
|
||||
return SHAPELIST();
|
||||
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
|
||||
std::vector<Nd4jLong> outShape;
|
||||
for(uint i = 0; i < shape::rank(inShape); ++i)
|
||||
if(i != dim)
|
||||
outShape.push_back(shape::shapeOf(inShape)[i]);
|
||||
|
||||
auto result = SHAPELIST();
|
||||
for(uint i = 0; i < numTads; ++i)
|
||||
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> dims;
|
||||
for (int e = 0; e < shape::rank(inShape); e++)
|
||||
if (e != dim)
|
||||
|
|
|
@ -30,6 +30,12 @@ namespace nd4j {
|
|||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||
|
||||
if(input->isEmpty()){
|
||||
output->p<double>(0, std::numeric_limits<double>::quiet_NaN());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int numZeros = 0;
|
||||
// for (int e = 0; e < input->lengthOf(); e++)
|
||||
// if ((*input)(e) == T(0))
|
||||
|
|
|
@ -113,16 +113,10 @@ DECLARE_SHAPE_FN(lstmBlock) {
|
|||
}
|
||||
ShapeUtils::updateStridesAndType(s, x, 'c');
|
||||
|
||||
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
|
||||
COPY_SHAPE(s, s1);
|
||||
COPY_SHAPE(s, s2);
|
||||
COPY_SHAPE(s, s3);
|
||||
COPY_SHAPE(s, s4);
|
||||
COPY_SHAPE(s, s5);
|
||||
COPY_SHAPE(s, s6);
|
||||
Nd4jLong *s1 = CONSTANT(s);
|
||||
|
||||
//7 outputs, all same shape/type
|
||||
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
|
||||
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -115,16 +115,10 @@ DECLARE_SHAPE_FN(lstmBlockCell) {
|
|||
|
||||
ShapeUtils::updateStridesAndType(s, xt, 'c');
|
||||
|
||||
Nd4jLong *s1, *s2, *s3, *s4, *s5, *s6;
|
||||
COPY_SHAPE(s, s1);
|
||||
COPY_SHAPE(s, s2);
|
||||
COPY_SHAPE(s, s3);
|
||||
COPY_SHAPE(s, s4);
|
||||
COPY_SHAPE(s, s5);
|
||||
COPY_SHAPE(s, s6);
|
||||
Nd4jLong *s1 = CONSTANT(s);
|
||||
|
||||
//7 outputs, all same shape: z, i, f, o, h, c, y
|
||||
return SHAPELIST(s, s1, s2, s3, s4, s5, s6);
|
||||
return SHAPELIST(s1, s1, s1, s1, s1, s1, s1);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -78,7 +78,6 @@ DECLARE_SHAPE_FN(broadcast_to) {
|
|||
REQUIRE_TRUE(inputShapeInfo[inputRank+1-i] == outShape[shapeLen-i] || inputShapeInfo[inputRank+1-i] == 1, 0, "BROADCAST_TO op: shape of input array %s can't be broadcasted to the shape %s !", ShapeUtils::shapeAsString(inputShapeInfo).c_str(), ShapeUtils::shapeAsString(outShape).c_str());
|
||||
|
||||
auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShapeInfo), shape::order(inputShapeInfo), outShape);
|
||||
|
||||
return SHAPELIST(outShapeInfo);
|
||||
}
|
||||
|
||||
|
|
|
@ -34,26 +34,26 @@ namespace nd4j {
|
|||
|
||||
bool replace = false;
|
||||
|
||||
auto arguments = block.getIArguments();
|
||||
if (block.width() == 2 && arguments->size() == 0) {
|
||||
auto axis = INPUT_VARIABLE(1);
|
||||
for (int e = 0; e < axis->lengthOf(); e++) {
|
||||
int ax = axis->e<int>(e);
|
||||
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
std::vector<int> arguments({});
|
||||
if(origArgs.size() > 0){
|
||||
for (int e = 0; e < origArgs.size(); e++) {
|
||||
int ax = origArgs[e];
|
||||
if (ax < 0)
|
||||
ax += x->rankOf();
|
||||
|
||||
arguments->emplace_back(ax);
|
||||
arguments.emplace_back(ax);
|
||||
}
|
||||
|
||||
replace = true;
|
||||
} else if (arguments->size() == 0) {
|
||||
} else {
|
||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
||||
arguments->emplace_back(e);
|
||||
arguments.emplace_back(e);
|
||||
}
|
||||
|
||||
// 0D edge case
|
||||
if (x->rankOf() == 0) {
|
||||
REQUIRE_TRUE(arguments->size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
if (!block.isInplace())
|
||||
output->assign(x);
|
||||
|
@ -62,24 +62,16 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if(block.isInplace()) { // in-place
|
||||
x->permutei(*arguments);
|
||||
x->permutei(arguments);
|
||||
STORE_RESULT(x);
|
||||
} else {
|
||||
if (!replace) { // not-in-place
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
// nd4j_printv("permute shape", *arguments);
|
||||
auto result = x->permute(*arguments);
|
||||
auto result = x->permute(arguments);
|
||||
output->assign(result);
|
||||
STORE_RESULT(output);
|
||||
delete result;
|
||||
} else {
|
||||
auto output = OUTPUT_VARIABLE(0); //->dup();
|
||||
output->assign(x);
|
||||
output->permutei(*arguments);
|
||||
}
|
||||
|
||||
//OVERWRITE_RESULT(output);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -92,20 +84,21 @@ namespace nd4j {
|
|||
|
||||
DECLARE_SHAPE_FN(permute) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto arguments = block.getIArguments();
|
||||
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
if (shape::rank(inputShape->at(0)) == 0) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
||||
} else if (inputShape->size() == 1 && !arguments->empty()) {
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
} else if (inputShape->size() == 2) {
|
||||
// dead end
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(inputShape->at(0))));
|
||||
} else if (inputShape->size() == 1 && !arguments.empty()) {
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
} else {
|
||||
if(arguments.size() == 0){
|
||||
//Reverse dimensions
|
||||
int rank = shape::rank(inputShape->at(0));
|
||||
for (int e = rank - 1; e >= 0; e--)
|
||||
arguments->emplace_back(e);
|
||||
arguments.emplace_back(e);
|
||||
}
|
||||
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
}
|
||||
|
||||
return shapeList;
|
||||
|
|
|
@ -35,9 +35,8 @@ namespace nd4j {
|
|||
auto arguments = block.getIArguments();
|
||||
int argsSize = arguments->size();
|
||||
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return ND4J_STATUS_OK; //No op
|
||||
}
|
||||
|
@ -96,9 +95,9 @@ namespace nd4j {
|
|||
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return ND4J_STATUS_OK; //No op
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
char order = 'c';
|
||||
|
@ -116,7 +115,8 @@ namespace nd4j {
|
|||
}
|
||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
shapeLength *=
|
||||
s->e<Nd4jLong>(e2);
|
||||
}
|
||||
long realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
|
@ -175,12 +175,12 @@ namespace nd4j {
|
|||
e = 0;
|
||||
}
|
||||
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (INPUT_VARIABLE(0)->isEmpty()) {
|
||||
REQUIRE_TRUE((int) arguments->size() == 1 && arguments->at(0) == -1, 0, "Reshape: when input is empty, iargs must be [-1]");
|
||||
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
||||
return SHAPELIST(newShape);
|
||||
}
|
||||
// //Special case: empty.reshape(-1) -> return empty
|
||||
// if (INPUT_VARIABLE(0)->isEmpty()) {
|
||||
// //
|
||||
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
||||
// return SHAPELIST(newShape);
|
||||
// }
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
|
||||
|
@ -197,9 +197,15 @@ namespace nd4j {
|
|||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew.push_back(0);
|
||||
} else {
|
||||
//Standard case
|
||||
long realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
|
@ -218,9 +224,16 @@ namespace nd4j {
|
|||
}
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
||||
return SHAPELIST(newShape);
|
||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||
Nd4jLong prod = 1;
|
||||
for (auto v:shapeOf)
|
||||
prod *= v;
|
||||
|
||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||
|
||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||
return SHAPELIST(CONSTANT(newShape));
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
||||
|
@ -236,8 +249,14 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= y->e<Nd4jLong>(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew[e] = 0;
|
||||
} else {
|
||||
long realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
}else {
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
|
|
|
@ -38,21 +38,26 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 1) {
|
|||
std::vector<int> arrsToDelete;
|
||||
int index = 0;
|
||||
bool allOfSameType = true;
|
||||
|
||||
auto theFirstRank = block.width() > 0?INPUT_VARIABLE(0)->rankOf():0;
|
||||
auto theFirstDatatype = block.width() > 0?INPUT_VARIABLE(0)->dataType():block.dataType();
|
||||
for(int i = 0; i < block.width(); ++i) {
|
||||
auto input = INPUT_VARIABLE(i);
|
||||
auto currentRank = input->rankOf();
|
||||
|
||||
if(!INPUT_VARIABLE(i)->isEmpty()) {
|
||||
// TODO: follow two lines are accordingly with current tf.concat spec. Commented for compatibility with legacy
|
||||
// REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
|
||||
// REQUIRE_TRUE(theFirstRank == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, theFirstRank);
|
||||
if(!input->isEmpty()) {
|
||||
|
||||
allOfSameType &= (INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType());
|
||||
if(INPUT_VARIABLE(i)->rankOf() == 0) {
|
||||
// FIXME, use this instead: block.dataType()
|
||||
auto vec = new NDArray('c', {1}, INPUT_VARIABLE(0)->dataType(), block.launchContext());
|
||||
vec->assign(INPUT_VARIABLE(i));
|
||||
allOfSameType &= (theFirstDatatype == input->dataType());
|
||||
if(input->rankOf() == 0) {
|
||||
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
|
||||
vec->assign(input);
|
||||
nonEmptyArrs.push_back(vec);
|
||||
arrsToDelete.push_back(index);
|
||||
}
|
||||
else{
|
||||
nonEmptyArrs.push_back(INPUT_VARIABLE(i));
|
||||
nonEmptyArrs.push_back(input);
|
||||
}
|
||||
++index;
|
||||
}
|
||||
|
@ -113,33 +118,24 @@ DECLARE_SHAPE_FN(concat) {
|
|||
|
||||
// first of all take into account possible presence of empty arrays
|
||||
// also if scalar is present -> use the shape of vector with length=1 instead
|
||||
std::vector<Nd4jLong*> nonEmptyArrShapes;
|
||||
std::vector<Nd4jLong*> arrShapes;
|
||||
std::vector<int> shapesToDelete;
|
||||
int index = 0;
|
||||
for(int i = 0; i < block.width(); ++i) {
|
||||
|
||||
if(!INPUT_VARIABLE(i)->isEmpty()) {
|
||||
|
||||
if(inputShape->at(i)[0] == 0) {
|
||||
// FIXME, use this instead: block.dataType()
|
||||
nonEmptyArrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||
}
|
||||
else{
|
||||
nonEmptyArrShapes.push_back(inputShape->at(i));
|
||||
arrShapes.push_back(inputShape->at(i));
|
||||
}
|
||||
++index;
|
||||
}
|
||||
}
|
||||
|
||||
const int numOfArrs = nonEmptyArrShapes.size();
|
||||
const int numOfArrs = arrShapes.size();
|
||||
|
||||
if(numOfArrs == 0){
|
||||
//All inputs are empty arrays -> return empty, mainly for TF import compatibility
|
||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
|
||||
return SHAPELIST(empty);
|
||||
}
|
||||
|
||||
const int rank = nonEmptyArrShapes[0][0]; // look up to first non-empty array
|
||||
const int rank = arrShapes[0][0];
|
||||
|
||||
int axis = INT_ARG(0);
|
||||
if(axis < 0)
|
||||
|
@ -149,33 +145,33 @@ DECLARE_SHAPE_FN(concat) {
|
|||
REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
|
||||
|
||||
for(int i = 1; i < numOfArrs; ++i)
|
||||
REQUIRE_TRUE(nonEmptyArrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
||||
REQUIRE_TRUE(arrShapes[i][0] == rank, 0, "CONCAT op: all input arrays must have the same rank !");
|
||||
|
||||
for(int i = 1; i < numOfArrs; ++i) {
|
||||
for(int dim = 0; dim < rank; ++dim)
|
||||
if(dim != axis)
|
||||
REQUIRE_TRUE(nonEmptyArrShapes[i][dim+1] == nonEmptyArrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
||||
REQUIRE_TRUE(arrShapes[i][dim+1] == arrShapes[0][dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
|
||||
}
|
||||
// ******** end of input validation ******** //
|
||||
|
||||
|
||||
Nd4jLong* outShapeInfo(nullptr);
|
||||
COPY_SHAPE(nonEmptyArrShapes[0], outShapeInfo);
|
||||
COPY_SHAPE(arrShapes[0], outShapeInfo);
|
||||
|
||||
// case when we have only one input array
|
||||
if(numOfArrs == 1) {
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||
return SHAPELIST(CONSTANT(outShapeInfo));
|
||||
}
|
||||
|
||||
for(int i = 1; i < numOfArrs; ++i)
|
||||
outShapeInfo[axis + 1] += nonEmptyArrShapes[i][axis + 1];
|
||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, nonEmptyArrShapes[0], shape::order(nonEmptyArrShapes[0]));
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||
|
||||
// delete dynamically allocated vectors shapes with length=1
|
||||
for(int index : shapesToDelete)
|
||||
RELEASE(nonEmptyArrShapes[index], block.getWorkspace());
|
||||
RELEASE(arrShapes[index], block.getWorkspace());
|
||||
|
||||
auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
|
||||
RELEASE(outShapeInfo, block.getWorkspace());
|
||||
|
|
|
@ -32,6 +32,11 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
||||
|
||||
if(input->isEmpty()){
|
||||
//No-op
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const bool exclusive = INT_ARG(0) == 1;
|
||||
const bool reverse = INT_ARG(1) == 1;
|
||||
|
||||
|
|
|
@ -36,6 +36,11 @@ CONFIGURABLE_OP_IMPL(cumsum, 1, 1, true, 0, 2) {
|
|||
|
||||
REQUIRE_TRUE(input->dataType() == output->dataType(), 0, "CumSum: input and output data types must be equal");
|
||||
|
||||
if(input->isEmpty()){
|
||||
//No-op
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (block.getIArguments()->size() == 2 && block.width() == 1) {
|
||||
// all at once case
|
||||
nd4j::ops::helpers::_prefix(block.launchContext(), scalar::Add, input, output, exclusive, reverse);
|
||||
|
|
|
@ -102,12 +102,6 @@ DECLARE_SHAPE_FN(gather) {
|
|||
if(axis < 0)
|
||||
axis += inputRank;
|
||||
|
||||
//Edge case: empty indices, empty input -> empty output
|
||||
if(block.width() > 1 && INPUT_VARIABLE(0)->isEmpty() && INPUT_VARIABLE(1)->isEmpty()){
|
||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(INPUT_VARIABLE(0)->dataType());
|
||||
return SHAPELIST(empty);
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
|
||||
|
||||
bool isEmpty = false;
|
||||
|
@ -119,11 +113,6 @@ DECLARE_SHAPE_FN(gather) {
|
|||
|
||||
int outputRank = inputRank + indicesRank - 1;
|
||||
|
||||
if(INPUT_VARIABLE(1)->isEmpty()) { //Empty indices -> empty output
|
||||
outputRank = 0;
|
||||
isEmpty = true;
|
||||
}
|
||||
|
||||
ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);
|
||||
|
||||
// fill output shapeInfo
|
||||
|
|
|
@ -33,6 +33,11 @@ namespace ops {
|
|||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
if(output->isEmpty()){
|
||||
//No-op
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int> axis;
|
||||
|
||||
if (block.width() > 1)
|
||||
|
|
|
@ -204,6 +204,8 @@ namespace nd4j {
|
|||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||
|
||||
static void getMKLDNNMemoryDescConv3d(
|
||||
|
@ -212,56 +214,60 @@ namespace nd4j {
|
|||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||
mkldnn::memory::desc* conv_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* conv_weights_md,
|
||||
mkldnn::memory::desc* conv_diff_weights_md, mkldnn::memory::desc* conv_bias_md, mkldnn::memory::desc* conv_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_weights_md,
|
||||
mkldnn::memory::desc* user_diff_weights_md, mkldnn::memory::desc* user_bias_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& conv_strides, mkldnn::memory::dims& conv_padding, mkldnn::memory::dims& conv_padding_r);
|
||||
|
||||
static void getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||
|
||||
static void getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* conv_diff_src_md, mkldnn::memory::desc* pool_dst_md, mkldnn::algorithm& algorithm,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r);
|
||||
#endif
|
||||
|
||||
static void conv2d(nd4j::LaunchContext &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
static void conv2d(nd4j::graph::Context &context, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
static void conv2d(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||
static void conv2d(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
|
||||
|
||||
static void conv2dBP(nd4j::LaunchContext & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
||||
static void conv2dBP(nd4j::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
|
||||
|
||||
static void conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
static void conv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
static void depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
static void depthwiseConv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
static void depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
static void depthwiseConv2dBP(nd4j::graph::Context & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
static void sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
static void sconv2d(nd4j::graph::Context & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW);
|
||||
|
||||
static void vol2col(nd4j::LaunchContext & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||
static void vol2col(nd4j::graph::Context & block, const NDArray& vol, NDArray& col, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||
|
||||
static void col2vol(nd4j::LaunchContext & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||
static void col2vol(nd4j::graph::Context & block, const NDArray& col, NDArray& vol, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW);
|
||||
|
||||
static void upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
|
||||
static void upsampling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW);
|
||||
|
||||
static void upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
|
||||
static void upsampling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW);
|
||||
|
||||
static void upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
|
||||
static void upsampling2dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW);
|
||||
|
||||
static void upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
|
||||
static void upsampling3dBP(nd4j::graph::Context & block, const NDArray& gradO, NDArray& gradI, const bool isNCDHW);
|
||||
|
||||
static void pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
|
||||
static void pooling2d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0);
|
||||
|
||||
static void pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
static void pooling3d(nd4j::graph::Context & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
|
||||
static void pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
static void pooling2dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
|
||||
static void pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
static void pooling3dBP(nd4j::graph::Context & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -153,7 +153,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
|
|||
if (inEWS == 1) {
|
||||
PRAGMA_OMP_SIMD_MAX(max)
|
||||
for (int i = 0; i < length; i++)
|
||||
max = nd4j::math::nd4j_max<T>(max, outBuff[i]);
|
||||
max = nd4j::math::nd4j_max<T>(max, inBuff[i]);
|
||||
|
||||
PRAGMA_OMP_SIMD_SUM(sum)
|
||||
for (int i = 0; i < length; i++) {
|
||||
|
@ -171,7 +171,7 @@ void softMaxForVector(nd4j::LaunchContext * context, const NDArray& input, NDArr
|
|||
|
||||
PRAGMA_OMP_SIMD_MAX(max)
|
||||
for (int i = 0; i < length; i++)
|
||||
max = nd4j::math::nd4j_max<T>(max, outBuff[i * inEWS]);
|
||||
max = nd4j::math::nd4j_max<T>(max, inBuff[i * inEWS]);
|
||||
|
||||
PRAGMA_OMP_SIMD_SUM(sum)
|
||||
for (int i = 0; i < length; i++) {
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
// @author Yurii Shyrma (iuriish@yahoo.com), created on 18.09.2018
|
||||
//
|
||||
|
||||
#include<ops/declarable/helpers/addBias.h>
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
#include<ops/declarable/helpers/addBias.h>
|
||||
#include <ops/declarable/helpers/im2col.h>
|
||||
#include <ops/declarable/helpers/col2im.h>
|
||||
#include <NDArrayFactory.h>
|
||||
|
@ -28,6 +28,122 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
#ifdef HAVE_MKLDNN
|
||||
using namespace mkldnn;
|
||||
|
||||
void ConvolutionUtils::getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
pool_strides = { sH, sW };
|
||||
pool_kernel = { kH, kW };
|
||||
pool_padding = { pH, pW };
|
||||
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? pooling_max
|
||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
||||
: pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
}
|
||||
|
||||
void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
pool_strides = { sD, sH, sW };
|
||||
pool_kernel = { kD, kH, kW };
|
||||
pool_padding = { pD, pH, pW };
|
||||
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||
(oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? pooling_max
|
||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
||||
: pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// [bS, iC, iD, iH, iW] is convoluted to [bS, iC, kD, kH, kW, oD, oH, oW]
|
||||
template <typename T>
|
||||
|
@ -400,7 +516,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y>
|
||||
static void conv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
static void conv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
|
||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
// weights [kH, kW, iC, oC] always
|
||||
|
@ -511,13 +627,13 @@ static void conv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDA
|
|||
// permutForOutput = {0, indOoH, indOoH+1, indIOioC}; // [bS, oC, oH, oW] -> [bS, oH, oW, oC]
|
||||
permutForOutput = {0, 3, 1, 2}; // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
|
||||
NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), &block);
|
||||
NDArray col('c', {bS, oH, oW, kH, kW, iC}, input->dataType(), input->getContext());
|
||||
NDArray* colP = col.permute({0, 5, 3, 4, 1, 2}); // {bS, iC, kH, kW, oH, oW}
|
||||
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), &block);
|
||||
NDArray mmulResult('f', {bS*oH*oW, oC}, output->dataType(), output->getContext());
|
||||
|
||||
//----- calculation of output -----//
|
||||
nd4j::LaunchContext ctx;
|
||||
helpers::im2col(ctx, *input, *colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, &block)); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
auto ctx = block.launchContext();
|
||||
helpers::im2col(*ctx, *input, *colP, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
MmulHelper::tensorDot(&col, weights, &mmulResult, {3,4,5}, {0,1,2}, {}); // [bS, oH, oW, kH, kW, iC] x [kH, kW, iC, oC] = [bS, oH, oW, oC]
|
||||
|
||||
//----- assign outTemp to output -----//
|
||||
|
@ -540,7 +656,7 @@ static void conv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDA
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y>
|
||||
static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
|
||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
// weights [kH, kW, iC, oC] always
|
||||
|
@ -730,7 +846,7 @@ static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const N
|
|||
|
||||
// ----- calculation of gradW ----- //
|
||||
if(gradW) {
|
||||
nd4j::LaunchContext * ctx = █
|
||||
auto ctx = block.launchContext();
|
||||
helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
nd4j::MmulHelper::tensorDot(&columns, gradO, gradW, {0,4,5}, gradOaxesForDot, {2, 0, 1, 3}); // [bS, iC, kH, kW, oH, oW] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [iC, kH, kW, oC]
|
||||
}
|
||||
|
@ -747,8 +863,8 @@ static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const N
|
|||
|
||||
//----- calculation of gradI -----//
|
||||
nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {indWoC}, {indIOioC}, {2, 3, 1, 0, 4, 5}); // [kH, kW, iC, oC]/[oC, iC, kH, kW]] x [bS, oH, oW, oC]/[bS, oC, oH, oW] = [kH, kW, iC, bS, oH, oW]
|
||||
nd4j::LaunchContext * ctx = █
|
||||
helpers::col2im(*ctx, columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
|
||||
|
||||
helpers::col2im(*block.launchContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
|
@ -758,7 +874,7 @@ static void conv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const N
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y>
|
||||
static void depthwiseConv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
|
||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
// weights [kH, kW, iC, mC] always
|
||||
|
@ -801,8 +917,7 @@ static void depthwiseConv2d_(nd4j::LaunchContext & block, const NDArray* input,
|
|||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||
NDArray* outputReshaped = output->reshape(output->ordering(), outReShape);
|
||||
|
||||
nd4j::LaunchContext * ctx = █
|
||||
helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
MmulHelper::tensorDot(&columns, weights, outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
|
||||
|
||||
if(bias)
|
||||
|
@ -816,7 +931,7 @@ static void depthwiseConv2d_(nd4j::LaunchContext & block, const NDArray* input,
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y>
|
||||
static void depthwiseConv2dBP_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
|
||||
// input [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW)
|
||||
// weights [kH, kW, iC, mC] always
|
||||
|
@ -867,8 +982,7 @@ static void depthwiseConv2dBP_(nd4j::LaunchContext & block, const NDArray* input
|
|||
|
||||
// ----- calculation of gradW and gradB ----- //
|
||||
|
||||
nd4j::LaunchContext * ctx = █
|
||||
helpers::im2col(*ctx, *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
helpers::im2col(*input->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
nd4j::MmulHelper::tensorDot(&columns, gradOreshaped, gradW, modifColumns, modifGradO1, {{2,0,1,3},{iC,kH*kW,mC}}); // [iC, kW*kH, bS*oH*oW] x [iC, bS*oH*oW, mC] = [iC, kH*kW, mC]
|
||||
|
||||
// ----- calculation of gradB ----- //
|
||||
|
@ -883,7 +997,7 @@ static void depthwiseConv2dBP_(nd4j::LaunchContext & block, const NDArray* input
|
|||
|
||||
//----- calculation of gradI -----//
|
||||
nd4j::MmulHelper::tensorDot(weights, gradO, &columns, {{2,0,1,3},{iC,kH*kW,mC}}, modifGradO2, modifColumns); // [iC, kH*kW, mC] x [iC, mC, bS*oH*oW] = [iC, kW*kH, bS*oH*oW]
|
||||
helpers::col2im(*ctx, columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
|
||||
helpers::col2im(*input->getContext(), columns, *gradI, sH, sW, pH, pW, iH, iW, dH, dW); // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW]
|
||||
|
||||
if(!isNCHW) {
|
||||
delete input;
|
||||
|
@ -895,7 +1009,7 @@ static void depthwiseConv2dBP_(nd4j::LaunchContext & block, const NDArray* input
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y>
|
||||
static void sconv2d_(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
static void sconv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
|
||||
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||
// weightsDepth [kH, kW, iC, mC] always
|
||||
|
@ -1100,125 +1214,9 @@ static void upsampling3dBP_(const NDArray& gradO, NDArray& gradI, const bool isN
|
|||
}
|
||||
|
||||
|
||||
#ifdef HAVE_MKLDNN
|
||||
using namespace mkldnn;
|
||||
|
||||
void ConvolutionUtils::getMKLDNNMemoryDescPool2d(
|
||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oH, oW };
|
||||
|
||||
pool_strides = { sH, sW };
|
||||
pool_kernel = { kH, kW };
|
||||
pool_padding = { pH, pW };
|
||||
pool_padding_r = { (oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? pooling_max
|
||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
||||
: pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCHW ? mkldnn::memory::format::nchw : mkldnn::memory::format::nhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format::nChw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCHW ? nchw : nhwc"
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCHW ? 0 : 0];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCHW ? 1 : 3];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCHW ? 2 : 1];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCHW ? 3 : 2];
|
||||
}
|
||||
}
|
||||
|
||||
void ConvolutionUtils::getMKLDNNMemoryDescPool3d(
|
||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, int poolingMode, int extraParam0, bool isNCDHW,
|
||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW,
|
||||
const NDArray* src, const NDArray* diff_src, const NDArray* dst, mkldnn::algorithm& algorithm,
|
||||
mkldnn::memory::desc* pool_src_md, mkldnn::memory::desc* pool_diff_src_md, mkldnn::memory::desc* pool_dst_md,
|
||||
mkldnn::memory::desc* user_src_md, mkldnn::memory::desc* user_diff_src_md, mkldnn::memory::desc* user_dst_md,
|
||||
mkldnn::memory::dims& pool_strides, mkldnn::memory::dims& pool_kernel, mkldnn::memory::dims& pool_padding, mkldnn::memory::dims& pool_padding_r) {
|
||||
mkldnn::memory::dims pool_src_tz = { bS, iC, iD, iH, iW };
|
||||
mkldnn::memory::dims pool_dst_tz = { bS, oC, oD, oH, oW };
|
||||
|
||||
pool_strides = { sD, sH, sW };
|
||||
pool_kernel = { kD, kH, kW };
|
||||
pool_padding = { pD, pH, pW };
|
||||
pool_padding_r = { (oD - 1) * sD - iD + kD - pD,
|
||||
(oH - 1) * sH - iH + kH - pH,
|
||||
(oW - 1) * sW - iW + kW - pW };
|
||||
|
||||
algorithm = poolingMode == 0 ? pooling_max
|
||||
: extraParam0 == 0 ? pooling_avg_exclude_padding
|
||||
: pooling_avg_include_padding;
|
||||
auto type = mkldnn::memory::data_type::f32;
|
||||
auto format = isNCDHW ? mkldnn::memory::format::ncdhw : mkldnn::memory::format::ndhwc;
|
||||
auto supposed_to_be_any_format = mkldnn::memory::format::nCdhw8c; // doesn't work with "any"
|
||||
|
||||
if (src != nullptr && src->getBuffer() != nullptr && pool_src_md != nullptr) {
|
||||
*pool_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_src_md->data.layout_desc.blocking.strides[0][0] = src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][1] = src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][2] = src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][3] = src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_src_md->data.layout_desc.blocking.strides[0][4] = src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (diff_src != nullptr && diff_src->getBuffer() != nullptr && pool_diff_src_md != nullptr) {
|
||||
*pool_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, supposed_to_be_any_format);
|
||||
*user_diff_src_md = mkldnn::memory::desc({ pool_src_tz }, type, format);
|
||||
user_diff_src_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][0] = diff_src->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][1] = diff_src->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][2] = diff_src->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][3] = diff_src->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_diff_src_md->data.layout_desc.blocking.strides[0][4] = diff_src->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
|
||||
if (dst != nullptr && dst->getBuffer() != nullptr && pool_dst_md != nullptr) {
|
||||
*pool_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, supposed_to_be_any_format);
|
||||
*user_dst_md = mkldnn::memory::desc({ pool_dst_tz }, type, format);
|
||||
user_dst_md->data.format = mkldnn_blocked; // overrides "format = isNCDHW ? ncdhw : ndhwc"
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][0] = dst->stridesOf()[isNCDHW ? 0 : 0];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][1] = dst->stridesOf()[isNCDHW ? 1 : 4];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][2] = dst->stridesOf()[isNCDHW ? 2 : 1];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][3] = dst->stridesOf()[isNCDHW ? 3 : 2];
|
||||
user_dst_md->data.layout_desc.blocking.strides[0][4] = dst->stridesOf()[isNCDHW ? 4 : 3];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void pooling2d_(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
static void pooling2d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
// input is [bS, iC, iH, iW]
|
||||
// output is [bS, iC, oH, oW]
|
||||
T* out = output.bufferAsT<T>();
|
||||
|
@ -1454,7 +1452,7 @@ static void pooling2d_(nd4j::LaunchContext & block, const NDArray& input, NDArra
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void pooling3d_(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
static void pooling3d_(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
// input is [bS, iC, iD, iH, iW]
|
||||
// output is [bS, iC, oD, oH, oW]
|
||||
T* out = output.bufferAsT<T>();
|
||||
|
@ -1726,7 +1724,7 @@ static void pooling3d_(nd4j::LaunchContext & block, const NDArray& input, NDArra
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void pooling2dBP_(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
static void pooling2dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
// input [bS, iC, iH, iW]
|
||||
// gradI [bS, iC, iH, iW] -> gradI is output in this function
|
||||
// gradO [bS, iC, oH, oW]
|
||||
|
@ -2015,7 +2013,7 @@ static void pooling2dBP_(nd4j::LaunchContext & block, const NDArray& input, cons
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void pooling3dBP_(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
static void pooling3dBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
// input [bS, iC, iD, iH, iW]
|
||||
// gradI [bS, iC, iD, iH, iW] -> gradI is output in this function
|
||||
// gradO [bS, iC, oD, oH, oW]
|
||||
|
@ -2351,61 +2349,61 @@ static void pooling3dBP_(nd4j::LaunchContext & block, const NDArray& input, cons
|
|||
|
||||
|
||||
|
||||
void ConvolutionUtils::conv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
void ConvolutionUtils::conv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), conv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::conv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
void ConvolutionUtils::conv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::depthwiseConv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::depthwiseConv2dBP(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), gradO->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::sconv2d(nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
void ConvolutionUtils::sconv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
|
||||
BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sconv2d_, (block, input, weightsDepth, weightsPoint, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::vol2col(nd4j::LaunchContext & block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||
void ConvolutionUtils::vol2col(nd4j::graph::Context& block, const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||
BUILD_SINGLE_SELECTOR(volume.dataType(), vol2col_, (volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::col2vol(nd4j::LaunchContext & block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||
void ConvolutionUtils::col2vol(nd4j::graph::Context& block, const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
|
||||
BUILD_SINGLE_SELECTOR(volume.dataType(), col2vol_, (columns, volume, sD, sH, sW, pD, pH, pW, dD, dH, dW), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::upsampling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
|
||||
void ConvolutionUtils::upsampling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling2d_, (input, output, factorH, factorW, isNCHW), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::upsampling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
|
||||
void ConvolutionUtils::upsampling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), upsampling3d_, (input, output, factorD, factorH, factorW, isNCDHW), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::upsampling2dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
|
||||
void ConvolutionUtils::upsampling2dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
|
||||
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling2dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::upsampling3dBP(nd4j::LaunchContext & block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
|
||||
void ConvolutionUtils::upsampling3dBP(nd4j::graph::Context& block, const NDArray& gradO, NDArray& gradI, const bool isNCHW) {
|
||||
BUILD_SINGLE_SELECTOR(gradO.dataType(), upsampling3dBP_, (gradO, gradI, isNCHW), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void ConvolutionUtils::pooling2d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
|
||||
void ConvolutionUtils::pooling2d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const PoolingType poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2d_, (block, input, output, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::pooling3d(nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
void ConvolutionUtils::pooling3d(nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3d_, (block, input, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::pooling2dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling2dBP_, (block, input, gradO, gradI, kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
|
||||
}
|
||||
void ConvolutionUtils::pooling3dBP(nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
void ConvolutionUtils::pooling3dBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), pooling3dBP_, (block, input, gradO, gradI, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::LaunchContext & block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void conv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void conv2dBP_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2d_, (const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void depthwiseConv2dBP_, (const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
BUILD_DOUBLE_TEMPLATE(template void sconv2d_, (nd4j::graph::Context& block, const NDArray* input, const NDArray* weightsDepth, const NDArray* weightsPoint, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW), LIBND4J_TYPES, FLOAT_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void upsampling2d_, (const NDArray& input, NDArray& output, const int factorH, const int factorW, const bool isNCHW), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void upsampling3d_, (const NDArray& input, NDArray& output, const int factorD, const int factorH, const int factorW, const bool isNCDHW), LIBND4J_TYPES);
|
||||
|
@ -2413,10 +2411,10 @@ BUILD_SINGLE_TEMPLATE(template void upsampling2dBP_, (const NDArray& gradO, NDAr
|
|||
BUILD_SINGLE_TEMPLATE(template void upsampling3dBP_, (const NDArray& gradO, NDArray& gradI, const bool isNCHW), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void vol2col_, (const NDArray& volume, NDArray& columns, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void col2vol_, (const NDArray& columns, NDArray& volume, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::LaunchContext & block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::LaunchContext & block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling2d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling3d_, (nd4j::graph::Context& block, const NDArray& input, NDArray& output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling2dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void pooling3dBP_, (nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int poolingMode, const int extraParam0), LIBND4J_TYPES);
|
||||
|
||||
}
|
||||
}
|
|
@ -227,15 +227,15 @@ void lstmBlockCell(const NDArray* xt, const NDArray* cLast, const NDArray* yLast
|
|||
|
||||
//NDArray* NDArrayFactory::create_( const char order, const std::vector<Nd4jLong> &shape, nd4j::DataType dataType, nd4j::memory::Workspace* workspace) {
|
||||
std::vector<Nd4jLong> shape = {bS, 4*numUnits};
|
||||
auto m = NDArrayFactory::create_('c', shape, xt->dataType(), nullptr);
|
||||
MmulHelper::mmul(&concatOut, W, m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
|
||||
*m += (*b); //addiRowVector
|
||||
auto m = NDArrayFactory::create('c', shape, xt->dataType());
|
||||
MmulHelper::mmul(&concatOut, W, &m, 1.0f, 0.0f, 'c'); //mmul: [bs, (nIn+numUnits)]* [(inSize+numUnits), 4*numUnits] = [bs, 4*numUnits] - C result array
|
||||
m += (*b); //addiRowVector
|
||||
|
||||
//Note: weights are ordered [inputGate, blockInput, forgetGate, outputGate] to match TF (TF code comments state [i,f,z/ci,o] but behaviour is [i,z,f,o])
|
||||
auto zi = (*m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
|
||||
auto zz = (*m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
|
||||
auto zf = (*m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
|
||||
auto zo = (*m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
|
||||
auto zi = (m)({0,0, 0, numUnits}); // z for input modulation gate, [bS, numUnits]
|
||||
auto zz = (m)({0,0, numUnits, 2*numUnits}); // z for block input, [bS, numUnits]
|
||||
auto zf = (m)({0,0, 2*numUnits, 3*numUnits}); // z for forget gate, [bS, numUnits]
|
||||
auto zo = (m)({0,0, 3*numUnits, 4*numUnits}); // z for output gate, [bS, numUnits]
|
||||
|
||||
if(peephole) { // add peephole connections: z + ct_1*Wc
|
||||
zi += (*cLast) * (*Wci); // add peephole connections to input gate
|
||||
|
|
|
@ -57,7 +57,7 @@ namespace helpers {
|
|||
ConvolutionUtils::calcPadding2D(pY, pX, oY, oX, inY, inX, params[0], params[1], params[2], params[3], params[6], params[7]);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - poolingMode; 9 - divisor;
|
||||
ConvolutionUtils::pooling2d(*block.launchContext(), *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
|
||||
ConvolutionUtils::pooling2d(block, *input, *values, kY, kX, sY, sX, pY, pX, dY, dX, PoolingType::MAX_POOL, 1);
|
||||
|
||||
if (nullptr != indices) {
|
||||
// for max_pool_with_argmax
|
||||
|
|
|
@ -53,9 +53,16 @@ namespace nd4j {
|
|||
dtype = nd4j::DataType::BOOL;
|
||||
|
||||
if(shape::isEmpty(x) || shape::isEmpty(y)) {
|
||||
//Edge case: broadcasting with empty array gives empty array output (behaviour to match TF for import cases)
|
||||
auto empty = ConstantShapeHelper::getInstance()->emptyShapeInfo(dtype);
|
||||
shapeList->push_back(empty);
|
||||
// this is edge case, [3, 4] + [] = []
|
||||
if ((shape::isEmpty(x) && shape::rank(x) == 0) || (shape::isEmpty(y) && shape::rank(y) == 0)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)));
|
||||
return shapeList;
|
||||
}
|
||||
|
||||
|
||||
Nd4jLong *newshape = nullptr;
|
||||
ShapeUtils::evalBroadcastShapeInfo(x, y, true, newshape, block.workspace());
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(newshape, dtype)));
|
||||
} else if (shape::isScalar(x) && shape::isScalar(y)) {
|
||||
if (shape::rank(x) >= shape::rank(y)) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x, dtype)));
|
||||
|
|
|
@ -2873,7 +2873,7 @@ namespace simdOps {
|
|||
const static functions::ReduceType reduceType = functions::ReduceType::MAX;
|
||||
|
||||
op_def static X startingValue(const X *input) {
|
||||
return -nd4j::DataTypeUtils::max<X>();
|
||||
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
op_def static X merge(X old, X opOutput, X *extraParams) {
|
||||
|
@ -3051,7 +3051,7 @@ namespace simdOps {
|
|||
const static functions::ReduceType reduceType = functions::ReduceType::MIN;
|
||||
|
||||
op_def static X startingValue(const X *input) {
|
||||
return nd4j::DataTypeUtils::max<X>();
|
||||
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
op_def static X merge(X old, X opOutput, X *extraParams) {
|
||||
|
@ -3831,7 +3831,7 @@ namespace simdOps {
|
|||
}
|
||||
|
||||
static _CUDA_HD inline X startingValue(const X *input) {
|
||||
return -nd4j::DataTypeUtils::max<X>();
|
||||
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||
|
@ -3890,7 +3890,7 @@ namespace simdOps {
|
|||
}
|
||||
|
||||
static _CUDA_HD inline X startingValue(const X *input) {
|
||||
return -nd4j::DataTypeUtils::max<X>();
|
||||
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||
|
@ -3958,7 +3958,7 @@ namespace simdOps {
|
|||
}
|
||||
|
||||
static _CUDA_HD inline X startingValue(const X *input) {
|
||||
return -nd4j::DataTypeUtils::max<X>();
|
||||
return -nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||
|
@ -3984,7 +3984,7 @@ namespace simdOps {
|
|||
}
|
||||
|
||||
static _CUDA_HD inline X startingValue(const X *input) {
|
||||
return nd4j::DataTypeUtils::max<X>();
|
||||
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||
|
@ -4040,7 +4040,7 @@ namespace simdOps {
|
|||
}
|
||||
|
||||
static _CUDA_HD inline X startingValue(const X *input) {
|
||||
return nd4j::DataTypeUtils::max<X>();
|
||||
return nd4j::DataTypeUtils::infOrMax<X>();
|
||||
}
|
||||
|
||||
static _CUDA_HD inline functions::indexreduce::IndexValue<X> startingIndexValue(X *input) {
|
||||
|
|
|
@ -580,6 +580,152 @@ TEST_F(BroadcastableOpsTests, broadcast_empty_1) {
|
|||
ASSERT_TRUE(z.equalsTo(zExp));
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_2) {
|
||||
|
||||
NDArray y('c', {1,4}, {1,2,3,4});
|
||||
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||
NDArray e = NDArrayFactory::create<double>('c', {0, 4});;
|
||||
|
||||
nd4j::ops::multiply op;
|
||||
auto status = op.execute({&x, &y}, {&x}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(e.isSameShape(x));
|
||||
ASSERT_TRUE(e.equalsTo(x));
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_3) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||
NDArray y('c', {}, {0.1}, nd4j::DataType::FLOAT32);
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||
|
||||
nd4j::ops::maximum op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_4) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||
|
||||
nd4j::ops::maximum op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_5) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||
NDArray y = NDArrayFactory::create<float>('c', {1, 0, 2});
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||
|
||||
nd4j::ops::realdiv op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_6) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 1});
|
||||
NDArray y = NDArrayFactory::create<float>('c', {1, 2}, {2, 2});
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2});;
|
||||
|
||||
nd4j::ops::realdiv op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_empty_7) {
|
||||
|
||||
NDArray x = NDArrayFactory::create<float>('c', {1, 0, 2, 1});
|
||||
NDArray y = NDArrayFactory::create<float>('c', {1, 2, 0});
|
||||
NDArray e = NDArrayFactory::create<float>('c', {1, 0, 2, 0});;
|
||||
|
||||
nd4j::ops::realdiv op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) {
|
||||
|
||||
NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4});
|
||||
NDArray x(nd4j::DataType::DOUBLE, y.getContext(), false);
|
||||
NDArray z(nd4j::DataType::BOOL, y.getContext(), false);
|
||||
NDArray zExp(nd4j::DataType::BOOL, y.getContext(), false);
|
||||
|
||||
nd4j::ops::greater op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(z.isSameShape(zExp));
|
||||
ASSERT_TRUE(z.equalsTo(zExp));
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) {
|
||||
|
||||
NDArray y('c', {1,4}, {1,2,3,4});
|
||||
NDArray x = NDArrayFactory::create<double>('c', {0, 4});
|
||||
NDArray e = NDArrayFactory::create<bool>('c', {0, 4});;
|
||||
|
||||
|
||||
nd4j::ops::greater op;
|
||||
auto result = op.execute({&x, &y}, {}, {}, {});
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
z->printShapeInfo("z");
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_TRUE(e.equalsTo(*z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, broadcast_bool_1) {
|
||||
|
||||
NDArray x('c', {3, 1, 2}, nd4j::DataType::FLOAT32);
|
||||
|
|
|
@ -2021,7 +2021,8 @@ TEST_F(ConvolutionTests1, vol2col_test1) {
|
|||
// PointersManager manager(columnsExpected.getContext());
|
||||
// manager.printDevContentOnHost<float>(columnsExpected.getSpecialBuffer(), columnsExpected.lengthOf());
|
||||
|
||||
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
graph::Context context(1);
|
||||
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
|
||||
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
||||
}
|
||||
|
@ -2052,7 +2053,8 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
|
|||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
|
||||
-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.});
|
||||
|
||||
nd4j::ops::ConvolutionUtils::vol2col(*LaunchContext::defaultContext(), volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
graph::Context context(1);
|
||||
nd4j::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW);
|
||||
|
||||
ASSERT_TRUE(columns.equalsTo(columnsExpected));
|
||||
}
|
||||
|
|
|
@ -1302,8 +1302,8 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test6) {
|
|||
TEST_F(DeclarableOpsTests10, broadcast_to_test7) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>(10.f);
|
||||
auto shape = NDArrayFactory::create<double>(0.f);
|
||||
auto exp = NDArrayFactory::create<double>(10.f);
|
||||
auto shape = NDArrayFactory::create<Nd4jLong>(1);
|
||||
auto exp = NDArrayFactory::create<double>('c', {1}, {10.});
|
||||
|
||||
nd4j::ops::broadcast_to op;
|
||||
auto results = op.execute({&input, &shape}, {}, {}, {});
|
||||
|
@ -2261,8 +2261,8 @@ TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) {
|
|||
|
||||
NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.251953f, 0.0f, 0.0f}, nd4j::DataType::FLOAT32);
|
||||
NDArray min('c', {0}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||
NDArray max('c', {0}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||
NDArray min('c', {}, {-63.65f}, nd4j::DataType::FLOAT32);
|
||||
NDArray max('c', {}, {0.1f}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::fake_quant_with_min_max_vars op;
|
||||
auto results = op.execute({&x, &min, &max}, {}, {});
|
||||
|
|
|
@ -136,7 +136,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692,
|
||||
-24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911});
|
||||
NDArray dLdwExp('c', {0}, {-227.77286});
|
||||
NDArray dLdwExp('c', {}, {-227.77286});
|
||||
NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002,
|
||||
-0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903});
|
||||
|
||||
|
@ -261,7 +261,7 @@ TEST_F(DeclarableOpsTests11, log_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {0}, {0.});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -583,7 +583,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52,
|
||||
-12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04});
|
||||
NDArray dLdwExp('c', {0}, {4515.84});
|
||||
NDArray dLdwExp('c', {}, {4515.84});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -702,7 +702,7 @@ TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {0}, {0.});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1031,7 +1031,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,
|
||||
-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5});
|
||||
NDArray dLdwExp('c', {0}, {288.});
|
||||
NDArray dLdwExp('c', {}, {288.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1150,7 +1150,7 @@ TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) {
|
|||
NDArray predictions('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {0}, {0.});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
|
||||
predictions.linspace(0.04, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -1519,7 +1519,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) {
|
|||
|
||||
NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048,
|
||||
-4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577});
|
||||
NDArray dLdwExp('c', {0}, {-91.52109});
|
||||
NDArray dLdwExp('c', {}, {-91.52109});
|
||||
NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126,
|
||||
-0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294});
|
||||
|
||||
|
@ -1642,7 +1642,7 @@ TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) {
|
|||
NDArray logits('c', {2,3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights(nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdwExp('c', {0}, {0.});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
labels.linspace(1);
|
||||
|
@ -2001,10 +2001,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) {
|
|||
|
||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125});
|
||||
NDArray dLdwExp('c', {0}, {1.38629});
|
||||
NDArray dLdwExp('c', {}, {1.38629});
|
||||
|
||||
logits = 2.;
|
||||
weights.assign(0.5);
|
||||
|
@ -2032,10 +2032,10 @@ TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) {
|
|||
|
||||
NDArray labels('c', {4}, {0,0,1,0}, nd4j::DataType::INT32);
|
||||
NDArray logits('c', {4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519});
|
||||
NDArray dLdwExp('c', {0}, {0.});
|
||||
NDArray dLdwExp('c', {}, {0.});
|
||||
|
||||
logits.linspace(-0.08, 0.04);
|
||||
weights = 0.5;
|
||||
|
@ -2466,7 +2466,7 @@ TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) {
|
|||
/////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) {
|
||||
|
||||
NDArray labels('c', {0}, {1}, nd4j::DataType::INT64);
|
||||
NDArray labels('c', {}, {1}, nd4j::DataType::INT64);
|
||||
NDArray logits('c', {2}, {-0.2, 0.3});
|
||||
|
||||
NDArray dLdpExp('c', {2}, {0.37754, -0.37754});
|
||||
|
|
|
@ -158,10 +158,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) {
|
|||
|
||||
NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4});
|
||||
NDArray predictions('c', {1,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7});
|
||||
NDArray dLdwExp('c', {0}, {1.3});
|
||||
NDArray dLdwExp('c', {}, {1.3});
|
||||
NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1});
|
||||
|
||||
predictions.linspace(-0.4, 0.2);
|
||||
|
@ -369,10 +369,10 @@ TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) {
|
|||
TEST_F(DeclarableOpsTests12, hinge_loss_14) {
|
||||
|
||||
NDArray logits('c', {3,4}, nd4j::DataType::DOUBLE);
|
||||
NDArray weights('c', {0}, {1.});
|
||||
NDArray weights('c', {}, {1.});
|
||||
NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0});
|
||||
|
||||
NDArray output('c', {0}, nd4j::DataType::DOUBLE);
|
||||
NDArray output('c', {}, {0.}, nd4j::DataType::DOUBLE);
|
||||
|
||||
logits.linspace(1.);
|
||||
weights.assign(1.);
|
||||
|
@ -594,7 +594,7 @@ TEST_F(DeclarableOpsTests12, TestMinimumBP_1) {
|
|||
TEST_F(DeclarableOpsTests12, reverse_test15) {
|
||||
|
||||
NDArray x('c', {5}, {1,2,3,4,5}, nd4j::DataType::DOUBLE);
|
||||
NDArray axis('c', {0}, {0}, nd4j::DataType::INT32);
|
||||
NDArray axis('c', {}, {0}, nd4j::DataType::INT32);
|
||||
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
||||
NDArray exp('c', {5}, {5,4,3,2,1}, nd4j::DataType::DOUBLE);
|
||||
|
||||
|
|
|
@ -123,6 +123,7 @@ TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) {
|
|||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
z->printIndexedBuffer("Reduced shape");
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
|
@ -213,3 +214,199 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_fill_1) {
|
||||
auto x = NDArrayFactory::empty<int>();
|
||||
auto y = NDArrayFactory::create<int>(1);
|
||||
|
||||
nd4j::ops::fill op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(y, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
||||
auto a = NDArrayFactory::create<float>('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f});
|
||||
auto b = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto c = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto d = NDArrayFactory::create<float>('c', {8, 12}, {-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976});
|
||||
auto e = NDArrayFactory::create<float>('c', {3});
|
||||
auto f = NDArrayFactory::create<float>('c', {3});
|
||||
auto g = NDArrayFactory::create<float>('c', {3});
|
||||
auto h = NDArrayFactory::create<float>('c', {12});
|
||||
|
||||
auto z0 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z1 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z2 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z3 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z4 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z5 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
auto z6 = NDArrayFactory::create<float>('c', {1, 3});
|
||||
|
||||
nd4j::ops::lstmBlockCell op;
|
||||
auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {});
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {0});
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
|
||||
nd4j::ops::stack op;
|
||||
auto result = op.execute({&x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
nd4j::ops::reduce_min sumOp;
|
||||
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
out->printShapeInfo("ReduceSum empty shape with keep dims");
|
||||
out->printIndexedBuffer("ReduceSum scalar");
|
||||
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||
delete res2;
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
|
||||
auto x = NDArrayFactory::empty<float>();
|
||||
auto e = NDArrayFactory::create<float>('c', {0});
|
||||
|
||||
nd4j::ops::stack op;
|
||||
auto result = op.execute({&x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
|
||||
auto x = NDArrayFactory::empty<float>();
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||
|
||||
nd4j::ops::stack op;
|
||||
auto result = op.execute({&x, &x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
|
||||
auto x = NDArrayFactory::create<float>('c', {0});
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||
|
||||
nd4j::ops::stack op;
|
||||
auto result = op.execute({&x, &x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
|
||||
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
nd4j::ops::reduce_min sumOp;
|
||||
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
|
||||
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||
delete res2;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) {
|
||||
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
nd4j::ops::reduce_max sumOp;
|
||||
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
|
||||
ASSERT_EQ(out->e<float>(0), -DataTypeUtils::infOrMax<float>());
|
||||
delete res2;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) {
|
||||
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
nd4j::ops::reduce_sum sumOp;
|
||||
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
ASSERT_EQ(out->e<float>(0), 0.f);
|
||||
delete res2;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) {
|
||||
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
nd4j::ops::reduce_mean sumOp;
|
||||
auto res2 = sumOp.execute({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
out->printShapeInfo("ReduceMean empty shape with keep dims");
|
||||
out->printIndexedBuffer("ReduceMean scalar");
|
||||
ASSERT_EQ(out->e<float>(0), 0.f);
|
||||
delete res2;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_argmax_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 0});
|
||||
auto y = NDArrayFactory::create<int>(0);
|
||||
auto e = NDArrayFactory::create<Nd4jLong>('c', {0});
|
||||
|
||||
nd4j::ops::argmax op;
|
||||
//nd4j::ops::reduce_max op;
|
||||
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
z->printShapeInfo("Z");
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_argmax_2) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 0});
|
||||
auto y = NDArrayFactory::create<int>(1);
|
||||
|
||||
nd4j::ops::argmax op;
|
||||
try {
|
||||
auto result = op.execute({&x, &y}, {&y}, {}, {}, {});
|
||||
ASSERT_TRUE(false);
|
||||
} catch (std::exception &e) {
|
||||
//
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
|
||||
auto x = NDArrayFactory::create<float>('c', {32, 0});
|
||||
|
||||
nd4j::ops::tanh op;
|
||||
auto result = op.execute({&x}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(x.isSameShape(z));
|
||||
ASSERT_EQ(x, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
|
|
@ -905,9 +905,9 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
|||
|
||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||||
auto x = NDArrayFactory::create<double>('c', {1}, {10});
|
||||
auto begin = NDArrayFactory::create<double>('c', {1}, {0.});
|
||||
auto end = NDArrayFactory::create<double>('c', {1}, {0.});
|
||||
auto stride = NDArrayFactory::create<double>('c', {1}, {1});
|
||||
auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
auto end = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
auto stride = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
//x.linspace(1);
|
||||
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||||
//exp.linspace(1);
|
||||
|
|
|
@ -2421,6 +2421,26 @@ TEST_F(DeclarableOpsTests5, log_softmax_test11) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, log_softmax_test12) {
|
||||
|
||||
auto input = NDArrayFactory::create<double>('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864});
|
||||
auto expOutput = NDArrayFactory::create<double>('c', {1, 4}, {-0.6738, -2.3525, -1.5104, -1.7472});
|
||||
|
||||
for (int i = 0; i < 10; ++i)
|
||||
{
|
||||
nd4j::ops::log_softmax op;
|
||||
auto results = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
|
||||
auto z = results->at(0);
|
||||
|
||||
ASSERT_EQ(Status::OK(), results->status());
|
||||
ASSERT_TRUE(expOutput.isSameShape(z));
|
||||
ASSERT_TRUE(expOutput.equalsTo(z, 1e-4));
|
||||
|
||||
delete results;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) {
|
||||
|
||||
|
|
|
@ -109,7 +109,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
|||
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
|
||||
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
|
||||
|
||||
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
|
||||
auto exp = NDArrayFactory::create<double>(10);
|
||||
|
||||
//matrix.linspace(1);
|
||||
|
||||
|
@ -119,7 +119,8 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
|||
|
||||
auto z = result->at(0);
|
||||
z->printShapeInfo("SS OS shape");
|
||||
ASSERT_TRUE(z->isEmpty());
|
||||
z->printIndexedBuffer("SS OS out");
|
||||
ASSERT_TRUE(z->equalsTo(exp));
|
||||
//ASSERT_EQ(exp, *z);
|
||||
|
||||
delete result;
|
||||
|
|
|
@ -608,6 +608,24 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
auto y = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
|
||||
nd4j::ops::concat op;
|
||||
auto result = op.execute({&x, &y}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, tile_bp_test3) {
|
||||
|
||||
|
|
|
@ -59,7 +59,8 @@ TEST_F(EmptyTests, Test_Create_Empty_2) {
|
|||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_1) {
|
||||
auto empty = NDArrayFactory::empty_<float>();
|
||||
// auto empty = NDArrayFactory::empty_<float>();
|
||||
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32);//NDArrayFactory::create_<float>('c', {(Nd4jLong)0}};
|
||||
auto vector = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||
|
||||
ASSERT_TRUE(empty->isEmpty());
|
||||
|
@ -82,9 +83,9 @@ TEST_F(EmptyTests, Test_Concat_1) {
|
|||
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_2) {
|
||||
auto empty = NDArrayFactory::empty_<float>();
|
||||
auto scalar1 = NDArrayFactory::create_<float>(1.0f);
|
||||
auto scalar2 = NDArrayFactory::create_<float>(2.0f);
|
||||
auto empty = new NDArray('c', {0}, nd4j::DataType::FLOAT32); //NDArrayFactory::empty_<float>();
|
||||
auto scalar1 = NDArrayFactory::create_<float>('c', {1}, {1.0f});
|
||||
auto scalar2 = NDArrayFactory::create_<float>('c', {1}, {2.0f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||
|
||||
ASSERT_TRUE(empty->isEmpty());
|
||||
|
@ -139,6 +140,23 @@ TEST_F(EmptyTests, Test_Reshape_2) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
||||
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
||||
auto e = NDArrayFactory::create<float>('c', {10, 0});
|
||||
|
||||
nd4j::ops::reshape op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_dup_1) {
|
||||
auto empty = NDArrayFactory::empty<int>();
|
||||
auto dup = empty.dup();
|
||||
|
@ -148,3 +166,47 @@ TEST_F(EmptyTests, Test_dup_1) {
|
|||
|
||||
delete dup;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_shaped_empty_1) {
|
||||
auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||
std::vector<Nd4jLong> shape = {2, 0, 3};
|
||||
|
||||
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||
ASSERT_EQ(0, empty.lengthOf());
|
||||
ASSERT_TRUE(empty.isEmpty());
|
||||
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||
ASSERT_EQ(3, empty.rankOf());
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_shaped_empty_2) {
|
||||
auto empty = NDArrayFactory::create<float>('c', {0, 3});
|
||||
std::vector<Nd4jLong> shape = {0, 3};
|
||||
|
||||
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||
ASSERT_EQ(0, empty.lengthOf());
|
||||
ASSERT_TRUE(empty.isEmpty());
|
||||
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||
ASSERT_EQ(2, empty.rankOf());
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_shaped_empty_3) {
|
||||
auto empty = NDArrayFactory::create<float>('c', {0});
|
||||
std::vector<Nd4jLong> shape = {0};
|
||||
|
||||
ASSERT_EQ(nd4j::DataType::FLOAT32, empty.dataType());
|
||||
ASSERT_EQ(0, empty.lengthOf());
|
||||
ASSERT_TRUE(empty.isEmpty());
|
||||
ASSERT_EQ(shape, empty.getShapeAsVector());
|
||||
ASSERT_EQ(1, empty.rankOf());
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_shaped_empty_4) {
|
||||
auto shape = ConstantShapeHelper::getInstance()->vectorShapeInfo(0, nd4j::DataType::FLOAT32);
|
||||
shape::printShapeInfoLinear("shape", shape);
|
||||
NDArray array(shape, true, nd4j::LaunchContext::defaultContext());
|
||||
std::vector<Nd4jLong> shapeOf({0});
|
||||
|
||||
ASSERT_TRUE(array.isEmpty());
|
||||
ASSERT_EQ(1, array.rankOf());
|
||||
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||
}
|
|
@ -668,3 +668,47 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
|
|||
delete row;
|
||||
delete erow;
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||
|
||||
int dim = 1;
|
||||
|
||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||
e.assign(std::numeric_limits<float>::infinity());
|
||||
|
||||
int dim = 1;
|
||||
|
||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2, 0, 3});
|
||||
auto z = NDArrayFactory::create<float>('c', {2, 3});
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 3});
|
||||
e.assign(-std::numeric_limits<float>::infinity());
|
||||
|
||||
int dim = 1;
|
||||
|
||||
NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1, x.shapeInfo(), nullptr);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(LegacyOpsTests, test_legacy_transform_float_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||
|
||||
NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr);
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue