Added missing Java ops wrappers (#122)
* Timeouts added * Added some ops * Ops added * Fixed tests * Minor fix * Some fixes * Digamma added * Small fixes * Timeouts added * Added some ops * Ops added * Fixed tests * Minor fix * Some fixes * Digamma added * Small fixes * Fused batch norm fixes- Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests switched off. * Added test for resize_bicubic. * Eliminated wasted in test of bicubic resize. * Switched off multithreading explicit. * HsvToRgb and RgbToHsv added * Eliminated waste comments and conform proper float constants. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed multithreading with resize_bicubic helper for cpu platform. Signed-off-by: shugeo <sgazeos@gmail.com> * ResizeBicubic was fixed. * Some fixes * Fix op name * Validation fixed. * Clarifications for tests * Wrappers and small fixes for new ops.master
parent
e0a9cb6c08
commit
f5068f3980
|
@ -22,6 +22,7 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TestName;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
|
@ -42,6 +43,8 @@ public class BaseDL4JTest {
|
|||
|
||||
@Rule
|
||||
public TestName name = new TestName();
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(30);
|
||||
|
||||
protected long startTime;
|
||||
protected int threadCountBefore;
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.datasets.datavec;
|
||||
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -70,6 +71,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
|||
@Slf4j
|
||||
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
protected Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.deeplearning4j.datasets.datavec;
|
||||
|
||||
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.commons.compress.utils.IOUtils;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -69,6 +70,9 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder temporaryFolder = new TemporaryFolder();
|
||||
|
||||
@Rule
|
||||
protected Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Test
|
||||
public void testsBasic() throws Exception {
|
||||
//Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
|
||||
|
|
|
@ -22,7 +22,9 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.nn.util.TestDataSetConsumer;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -37,6 +39,9 @@ import static org.junit.Assert.*;
|
|||
|
||||
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
protected Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Test
|
||||
public void testNextAndReset() throws Exception {
|
||||
int epochs = 3;
|
||||
|
|
|
@ -161,7 +161,7 @@ public class EvalTest extends BaseDL4JTest {
|
|||
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testEvaluationWithMetaData() throws Exception {
|
||||
|
||||
RecordReader csv = new CSVRecordReader();
|
||||
|
|
|
@ -317,7 +317,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
assertEquals(paramsMLN, paramsGraph);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testIrisFitMultiDataSetIterator() throws Exception {
|
||||
|
||||
RecordReader rr = new CSVRecordReader(0, ',');
|
||||
|
|
|
@ -103,7 +103,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
|
|||
assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testTsne() throws Exception {
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
Nd4j.getRandom().setSeed(123);
|
||||
|
|
|
@ -28,7 +28,9 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
|||
import org.deeplearning4j.nn.weights.WeightInitRelu;
|
||||
import org.deeplearning4j.nn.weights.WeightInitXavier;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.activations.impl.ActivationLReLU;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -42,6 +44,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
|||
import org.nd4j.resources.Resources;
|
||||
|
||||
import java.io.File;
|
||||
import java.sql.Time;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
|
@ -55,6 +58,9 @@ import static org.junit.Assert.*;
|
|||
*/
|
||||
public class RegressionTest050 extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Override
|
||||
public DataType getDataType(){
|
||||
return DataType.FLOAT;
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
|
||||
|
@ -54,6 +55,9 @@ public class ModelGuesserTest extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
|
||||
@Test
|
||||
public void testModelGuessFile() throws Exception {
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.junit.Ignore;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
|
||||
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -60,6 +61,9 @@ public class FullModelComparisons extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Test
|
||||
public void lstmTest() throws IOException, UnsupportedKerasConfigurationException,
|
||||
InvalidKerasConfigurationException, InterruptedException {
|
||||
|
|
|
@ -30,7 +30,7 @@ import java.io.IOException;
|
|||
*/
|
||||
public class TimeSeriesGeneratorImportTest extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
@Test(timeout=300000)
|
||||
public void importTimeSeriesTest() throws IOException, InvalidKerasConfigurationException {
|
||||
String path = "modelimport/keras/preprocessing/timeseries_generator.json";
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ public class TokenizerImportTest extends BaseDL4JTest {
|
|||
ClassLoader classLoader = getClass().getClassLoader();
|
||||
|
||||
|
||||
@Test
|
||||
@Test(timeout=300000)
|
||||
public void importTest() throws IOException, InvalidKerasConfigurationException {
|
||||
|
||||
String path = "modelimport/keras/preprocessing/tokenizer.json";
|
||||
|
@ -51,7 +51,7 @@ public class TokenizerImportTest extends BaseDL4JTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout=300000)
|
||||
public void importNumWordsNullTest() throws IOException, InvalidKerasConfigurationException {
|
||||
|
||||
String path = "modelimport/keras/preprocessing/tokenizer_num_words_null.json";
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.deeplearning4j.models;
|
||||
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.nd4j.shade.guava.primitives.Doubles;
|
||||
import lombok.val;
|
||||
|
@ -75,6 +76,9 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
|
|||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
private File textFile, binaryFile, textFile2;
|
||||
private File fastTextRaw, fastTextZip, fastTextGzip;
|
||||
String pathToWriteto;
|
||||
|
|
|
@ -62,7 +62,7 @@ public class VectorsConfigurationTest extends BaseDL4JTest {
|
|||
assertEquals(configuration, configuration2);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testFromW2V() throws Exception {
|
||||
VectorsConfiguration configuration = new VectorsConfiguration();
|
||||
configuration.setHugeModelExpected(true);
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
package org.deeplearning4j.models.word2vec;
|
||||
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.shade.guava.primitives.Doubles;
|
||||
import org.nd4j.shade.guava.primitives.Ints;
|
||||
import lombok.val;
|
||||
|
@ -68,6 +70,9 @@ public class Word2VecTests extends BaseDL4JTest {
|
|||
private String pathToWriteto;
|
||||
private WordVectors googleModel;
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Before
|
||||
public void before() throws Exception {
|
||||
File googleModelTextFile = new ClassPathResource("word2vecserialization/google_news_30.txt").getFile();
|
||||
|
|
|
@ -52,7 +52,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testConsumeOnEqualVocabs() throws Exception {
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
@ -99,7 +99,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testConsumeOnNonEqualVocabs() throws Exception {
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
t.setTokenPreProcessor(new CommonPreprocessor());
|
||||
|
|
|
@ -12,6 +12,7 @@ import org.junit.Ignore;
|
|||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
||||
|
@ -27,6 +28,9 @@ import static org.junit.Assert.assertEquals;
|
|||
@Slf4j
|
||||
public class FastTextTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
|
||||
private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
|
||||
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");
|
||||
|
|
|
@ -55,7 +55,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void hasNext() throws Exception {
|
||||
SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt"));
|
||||
|
||||
|
@ -77,7 +77,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest {
|
|||
assertEquals(97162, cnt);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testSpeedComparison1() throws Exception {
|
||||
SentenceIterator iterator = new MutipleEpochsSentenceIterator(
|
||||
new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 25);
|
||||
|
|
|
@ -82,7 +82,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
|||
assertEquals(neighbours, nearestWords.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testUnkSerialization_1() throws Exception {
|
||||
val inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||
|
||||
|
@ -142,7 +142,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testW2VEmbeddingLayerInit() throws Exception {
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
|
|||
/**
|
||||
* Basically all we want from this test - being able to finish without exceptions.
|
||||
*/
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testIterator1() throws Exception {
|
||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
|
|
|
@ -22,6 +22,7 @@ import lombok.val;
|
|||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
|
||||
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
|
||||
|
@ -43,6 +44,7 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.File;
|
||||
import java.sql.Time;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
|
@ -53,6 +55,9 @@ import static org.junit.Assert.*;
|
|||
*/
|
||||
public class VocabConstructorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
protected static final Logger log = LoggerFactory.getLogger(VocabConstructorTest.class);
|
||||
|
||||
TokenizerFactory t = new DefaultTokenizerFactory();
|
||||
|
|
|
@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
|
|||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class AsyncLabelAwareIteratorTest extends BaseDL4JTest {
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void nextDocument() throws Exception {
|
||||
SentenceIterator sentence = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt"));
|
||||
BasicLabelAwareIterator backed = new BasicLabelAwareIterator.Builder(sentence).build();
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package org.deeplearning4j.text.documentiterator;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
|
||||
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
|
||||
|
@ -33,6 +35,9 @@ import static org.junit.Assert.assertEquals;
|
|||
*/
|
||||
public class BasicLabelAwareIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
|
|||
*/
|
||||
public class AggregatingSentenceIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testHasNext() throws Exception {
|
||||
File file = Resources.asFile("/big/raw_sentences.txt");
|
||||
BasicLineIterator iterator = new BasicLineIterator(file);
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package org.deeplearning4j.text.sentenceiterator;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
@ -32,6 +34,9 @@ import static org.junit.Assert.assertEquals;
|
|||
*/
|
||||
public class BasicLineIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ import static org.junit.Assert.assertEquals;
|
|||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class MutipleEpochsSentenceIteratorTest extends BaseDL4JTest {
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void hasNext() throws Exception {
|
||||
SentenceIterator iterator = new MutipleEpochsSentenceIterator(
|
||||
new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 100);
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
package org.deeplearning4j.text.sentenceiterator;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.resources.Resources;
|
||||
|
@ -33,6 +35,9 @@ import static org.junit.Assert.assertTrue;
|
|||
*/
|
||||
public class PrefetchingSentenceIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
protected static final Logger log = LoggerFactory.getLogger(PrefetchingSentenceIteratorTest.class);
|
||||
|
||||
@Test
|
||||
|
|
|
@ -205,7 +205,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testBertWordPieceTokenizer10() throws Exception {
|
||||
File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
|
||||
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);
|
||||
|
|
|
@ -27,7 +27,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.*;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
|
@ -38,9 +39,6 @@ import org.deeplearning4j.parallelism.inference.InferenceObservable;
|
|||
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
|
||||
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
|
||||
import org.deeplearning4j.util.ModelSerializer;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
@ -66,6 +64,9 @@ public class ParallelInferenceTest extends BaseDL4JTest {
|
|||
private static MultiLayerNetwork model;
|
||||
private static DataSetIterator iterator;
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(300);
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
if (model == null) {
|
||||
|
|
|
@ -100,7 +100,7 @@ public class ManualTests {
|
|||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testTsne() throws Exception {
|
||||
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
|
||||
Nd4j.getRandom().setSeed(123);
|
||||
|
@ -208,7 +208,7 @@ public class ManualTests {
|
|||
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test(timeout = 300000)
|
||||
public void testWord2VecPlot() throws Exception {
|
||||
File inputFile = Resources.asFile("big/raw_sentences.txt");
|
||||
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
|
||||
|
|
|
@ -32,7 +32,7 @@ namespace ops {
|
|||
->setAllowedOutputTypes({ALL_FLOATS});
|
||||
}
|
||||
|
||||
CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) {
|
||||
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
|
||||
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
|
||||
auto scale = INPUT_VARIABLE(1); // [iD]
|
||||
auto offset = INPUT_VARIABLE(2); // [iD]
|
||||
|
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) {
|
|||
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
|
||||
}
|
||||
else {
|
||||
REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
|
||||
std::vector<Nd4jLong> shape = {iD};
|
||||
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
|
||||
|
|
|
@ -682,27 +682,29 @@ namespace helpers {
|
|||
pY2[pt_index], pY3[pt_index]);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename F>
|
||||
static void
|
||||
bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) {
|
||||
std::vector<WeightsAndIndices> xWais(resizerState.outWidth);
|
||||
|
||||
computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais);
|
||||
|
||||
const auto numChannels = resizerState.channels;
|
||||
const Nd4jLong inRowWidth = resizerState.inWidth * numChannels;
|
||||
const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
|
||||
const auto batchNum = resizerState.batchSize;
|
||||
const auto outHeight = resizerState.outHeight;
|
||||
const auto outWidth = resizerState.outWidth;
|
||||
|
||||
const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
|
||||
float* pOutputY = output->dataBuffer()->primaryAsT<float>(); // output is float anyway
|
||||
std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
|
||||
F* pOutputY = output->dataBuffer()->primaryAsT<F>(); // output is float anyway
|
||||
std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto b = start; b < stop; ++b) {
|
||||
auto pInput = inputPtr + b * inBatchWidth;
|
||||
|
||||
for (auto y = 0; y < resizerState.outHeight; ++y) {
|
||||
auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels];
|
||||
for (auto y = 0; y < outHeight; ++y) {
|
||||
auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels];
|
||||
|
||||
WeightsAndIndices yWai;
|
||||
if (halfPixelCenters) {
|
||||
|
@ -713,16 +715,16 @@ namespace helpers {
|
|||
resizerState.heightScale, y, resizerState.inHeight, &yWai);
|
||||
}
|
||||
// Make pointers represent offsets of data in inputBPtr.
|
||||
const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth;
|
||||
const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth;
|
||||
const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth;
|
||||
const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth;
|
||||
const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth;
|
||||
const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth;
|
||||
const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth;
|
||||
const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth;
|
||||
|
||||
if (numChannels == 3) {
|
||||
// Manually unroll case of 3 channels.
|
||||
float cached_value_0[4] = {0};
|
||||
float cached_value_1[4] = {0};
|
||||
float cached_value_2[4] = {0};
|
||||
F cached_value_0[4] = {0};
|
||||
F cached_value_1[4] = {0};
|
||||
F cached_value_2[4] = {0};
|
||||
for (auto x = 0; x < resizerState.outWidth; ++x) {
|
||||
const WeightsAndIndices &xWai = xWais[x];
|
||||
// Shift values in cached_value_* to fill first '_advance' values.
|
||||
|
@ -854,7 +856,7 @@ namespace helpers {
|
|||
}
|
||||
for (auto c = 0; c < numChannels; ++c) {
|
||||
pOutput[x * numChannels + c] =
|
||||
compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1,
|
||||
(F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1,
|
||||
xWai._weight2, xWai._weight3);
|
||||
}
|
||||
}
|
||||
|
@ -862,7 +864,7 @@ namespace helpers {
|
|||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, resizerState.batchSize);
|
||||
samediff::Threads::parallel_tad(func, 0, batchNum);
|
||||
}
|
||||
|
||||
// simplified bicubic resize without antialiasing
|
||||
|
@ -873,7 +875,7 @@ namespace helpers {
|
|||
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
|
||||
int res = st.validateAndCreateOutput(image, width, height);
|
||||
if (res == Status::OK())
|
||||
bicubicInterpolateWithCaching<T>(image, st, halfPixelAlign, output);
|
||||
bicubicInterpolateWithCaching<T, float>(image, st, halfPixelAlign, output);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
|
@ -975,6 +975,118 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
|
||||
0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511,
|
||||
0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613,
|
||||
0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989,
|
||||
0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312,
|
||||
0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699});
|
||||
|
||||
NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 1}, {
|
||||
0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f,
|
||||
0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f,
|
||||
0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f,
|
||||
0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f,
|
||||
0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f,
|
||||
0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f,
|
||||
0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f,
|
||||
0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f,
|
||||
0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f,
|
||||
0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f,
|
||||
0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f,
|
||||
0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f,
|
||||
0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f,
|
||||
0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f,
|
||||
0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f,
|
||||
0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f,
|
||||
0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f,
|
||||
0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f,
|
||||
0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f,
|
||||
0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f,
|
||||
0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f,
|
||||
0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f,
|
||||
0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f,
|
||||
0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f,
|
||||
0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f,
|
||||
0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f,
|
||||
0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f,
|
||||
0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f,
|
||||
0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f,
|
||||
0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f,
|
||||
0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f,
|
||||
0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f,
|
||||
0.9139262f, 0.92068815f
|
||||
});
|
||||
auto size = NDArrayFactory::create<int>({9, 9});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
auto results = op.execute({&input, &size}, {}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
|
||||
// result->printBuffer("Resized to 9x9");
|
||||
// expected.printBuffer("Expect for 9x9");
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
|
||||
|
||||
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
|
||||
0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338,
|
||||
0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859,
|
||||
0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232,
|
||||
0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986,
|
||||
0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547,
|
||||
0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666,
|
||||
0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496,
|
||||
0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752,
|
||||
0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968,
|
||||
0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859});
|
||||
|
||||
|
||||
auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
|
||||
0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
|
||||
0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
|
||||
0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
|
||||
0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
|
||||
0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f,
|
||||
0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f,
|
||||
0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f,
|
||||
|
||||
0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f,
|
||||
0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f,
|
||||
0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f,
|
||||
0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f,
|
||||
0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f,
|
||||
0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f,
|
||||
|
||||
0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f,
|
||||
0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f,
|
||||
0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f,
|
||||
0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f,
|
||||
0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f
|
||||
});
|
||||
|
||||
auto size = NDArrayFactory::create<int>({9, 9});
|
||||
nd4j::ops::resize_bicubic op;
|
||||
auto results = op.execute({&input, &size}, {}, {}, {true, false});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
NDArray* result = results->at(0);
|
||||
|
||||
// result->printBuffer("Resized to 9x9");
|
||||
// expected.printBuffer("Expect for 9x9");
|
||||
ASSERT_TRUE(testData.isSameShape(result));
|
||||
ASSERT_TRUE(testData.equalsTo(result));
|
||||
delete results;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {
|
||||
|
||||
|
|
|
@ -232,10 +232,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
|
||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
|
||||
import org.nd4j.linalg.api.ops.random.custom.*;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
|
||||
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
|
||||
|
@ -384,6 +381,18 @@ public class DifferentialFunctionFactory {
|
|||
return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) {
|
||||
return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) {
|
||||
return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable();
|
||||
}
|
||||
|
||||
public SDVariable randomShuffle(SDVariable values, int... seeds) {
|
||||
return new RandomShuffle(sameDiff(), values, seeds).outputVariable();
|
||||
}
|
||||
|
||||
/**
|
||||
* Exponential distribution: P(x) = lambda * exp(-lambda * x)
|
||||
*
|
||||
|
|
|
@ -3,10 +3,7 @@ package org.nd4j.autodiff.samediff.ops;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustHue;
|
||||
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
|
||||
import org.nd4j.linalg.api.ops.custom.RandomCrop;
|
||||
import org.nd4j.linalg.api.ops.custom.*;
|
||||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
|
@ -119,4 +116,26 @@ public class SDImage extends SDOps {
|
|||
SDVariable out = new RandomCrop(sd, input, shape).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting array from HSV to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable rgbToHsv(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new RgbToHsv(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converting image from HSV to RGB format
|
||||
* @param name name
|
||||
* @param input 3D image
|
||||
* @return 3D image
|
||||
*/
|
||||
public SDVariable hsvToRgb(String name, @NonNull SDVariable input) {
|
||||
SDVariable out = new HsvToRgb(sd, input).outputVariable();
|
||||
return updateVariableNameAndReference(out, name);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
|
||||
|
||||
|
@ -295,4 +296,43 @@ public class SDRandom extends SDOps {
|
|||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable with Gamma distribution
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param alpha distribution parameter
|
||||
* @param beta distribution parameter
|
||||
* @param shape Shape of the new variable
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) {
|
||||
SDVariable ret = f().randomGamma(alpha, beta, shape);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable with Poission distribution
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param lambda rate distribution parameter
|
||||
* @param shape Shape of the new variable
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) {
|
||||
SDVariable ret = f().randomPoisson(shape, lambda, seeds);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new random SDVariable by random shuffle
|
||||
*
|
||||
* @param name Name of the output variable
|
||||
* @param value array to shuffle
|
||||
* @return new SDVariable
|
||||
*/
|
||||
public SDVariable shuffle(String name, SDVariable value, int... seeds) {
|
||||
SDVariable ret = f().randomShuffle(value, seeds);
|
||||
return updateVariableNameAndReference(ret, name);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -572,6 +572,9 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.random.custom.RandomBernoulli.class,
|
||||
org.nd4j.linalg.api.ops.random.custom.RandomExponential.class,
|
||||
org.nd4j.linalg.api.ops.random.custom.RandomNormal.class,
|
||||
org.nd4j.linalg.api.ops.random.custom.RandomGamma.class,
|
||||
org.nd4j.linalg.api.ops.random.custom.RandomPoisson.class,
|
||||
org.nd4j.linalg.api.ops.random.custom.RandomShuffle.class,
|
||||
org.nd4j.linalg.api.ops.random.impl.AlphaDropOut.class,
|
||||
org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution.class,
|
||||
org.nd4j.linalg.api.ops.random.impl.BinomialDistribution.class,
|
||||
|
@ -588,6 +591,8 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
|
||||
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
|
||||
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
||||
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
||||
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
|
||||
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
||||
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
||||
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
||||
|
@ -601,7 +606,10 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.Polygamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
|
||||
org.nd4j.linalg.api.ops.custom.Roll.class,
|
||||
org.nd4j.linalg.api.ops.custom.ToggleBits.class
|
||||
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
|
||||
org.nd4j.linalg.api.ops.custom.Igamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
||||
org.nd4j.linalg.api.ops.custom.Digamma.class
|
||||
);
|
||||
|
||||
static {
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Digamma extends DynamicCustomOp {
|
||||
public Digamma(@NonNull INDArray x) {
|
||||
addInputArgument(x);
|
||||
}
|
||||
|
||||
public Digamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) {
|
||||
super("", sameDiff, new SDVariable[]{x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "digamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Digamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -19,15 +19,23 @@ import lombok.NonNull;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class FusedBatchNorm extends DynamicCustomOp {
|
||||
|
||||
private DataType outputDataType;
|
||||
|
||||
public FusedBatchNorm() {}
|
||||
|
||||
public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset,
|
||||
|
@ -38,6 +46,7 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
|||
if (yOut != null && batchMeanOut != null && batchMeanVar != null) {
|
||||
addOutputArgument(yOut, batchMeanOut, batchMeanVar);
|
||||
}
|
||||
this.outputDataType = x.dataType();
|
||||
}
|
||||
|
||||
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
|
||||
|
@ -51,14 +60,25 @@ public class FusedBatchNorm extends DynamicCustomOp {
|
|||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "FusedBatchNormV2";
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"FusedBatchNormV2","FusedBatchNormV3"};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
boolean isNchw = attributesForNode.containsKey("data_format") && attributesForNode.get("data_format").getS().toStringUtf8().equalsIgnoreCase("NCHW");
|
||||
boolean training = !attributesForNode.containsKey("is_training") ? true : attributesForNode.get("is_training").getB();
|
||||
addIArgument(isNchw ? 1 : 0);
|
||||
addIArgument(training ? 1 : 0);
|
||||
if(attributesForNode.containsKey("T")){
|
||||
outputDataType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class HsvToRgb extends DynamicCustomOp {
|
||||
|
||||
public HsvToRgb(INDArray input) {
|
||||
addInputArgument(input);
|
||||
}
|
||||
|
||||
public HsvToRgb(SameDiff sameDiff, SDVariable input) {
|
||||
super(sameDiff, new SDVariable[]{input});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "hsv_to_rgb";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "HSVToRGB";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Igamma extends DynamicCustomOp {
|
||||
public Igamma(@NonNull INDArray n, @NonNull INDArray x) {
|
||||
Preconditions.checkArgument(n.shape() != x.shape(),
|
||||
"Igamma: n and x must have the same shapes");
|
||||
addInputArgument(n,x);
|
||||
}
|
||||
|
||||
public Igamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
|
||||
this(n,x);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public Igamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
|
||||
super("", sameDiff, new SDVariable[]{n ,x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "igamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Igamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Igammac extends DynamicCustomOp {
|
||||
public Igammac(@NonNull INDArray n, @NonNull INDArray x) {
|
||||
Preconditions.checkArgument(n.shape() != x.shape(),
|
||||
"Igamma: n and x must have the same shapes");
|
||||
addInputArgument(n,x);
|
||||
}
|
||||
|
||||
public Igammac(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
|
||||
this(n,x);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public Igammac(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
|
||||
super("", sameDiff, new SDVariable[]{n ,x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "igammac";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Igammac";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class RgbToHsv extends DynamicCustomOp {
|
||||
|
||||
public RgbToHsv(INDArray input) {
|
||||
addInputArgument(input);
|
||||
}
|
||||
|
||||
public RgbToHsv(SameDiff sameDiff, SDVariable input) {
|
||||
super(sameDiff, new SDVariable[]{input});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "rgb_to_hsv";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "RGBToHSV";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -61,7 +61,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,7 +61,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,7 +60,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
List<DataType> out = new ArrayList<>();
|
||||
for( int i=0; i<numSegments; i++ ){
|
||||
out.add(inputDataTypes.get(0));
|
||||
|
|
|
@ -62,7 +62,7 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
|
|||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
|
||||
//TODO Allow customizing output type
|
||||
return Collections.singletonList(Nd4j.defaultFloatingPointType());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.random.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class RandomGamma extends DynamicCustomOp {
|
||||
|
||||
public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta,
|
||||
int... seeds) {
|
||||
if (beta != null) {
|
||||
addInputArgument(shape,alpha,beta);
|
||||
}
|
||||
addInputArgument(shape,alpha);
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta) {
|
||||
|
||||
this(shape,alpha,beta,0,0);
|
||||
}
|
||||
|
||||
public RandomGamma(@NonNull SameDiff sameDiff, @NonNull SDVariable shape,
|
||||
@NonNull SDVariable alpha, SDVariable beta, int... seeds) {
|
||||
super(null, sameDiff, beta != null ? new SDVariable[]{shape, alpha, beta} :
|
||||
new SDVariable[]{shape, alpha});
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "random_gamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "RandomGamma";
|
||||
}
|
||||
|
||||
private DataType outputDataType = DataType.FLOAT;
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
if(attributesForNode.containsKey("alpha")) {
|
||||
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("alpha").getType());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null, "Expected exactly input datatypes for %s, got null", getClass());
|
||||
return Collections.singletonList(outputDataType);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.random.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class RandomPoisson extends DynamicCustomOp {
|
||||
|
||||
private DataType outputDataType = DataType.FLOAT;
|
||||
|
||||
public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate, int... seeds) {
|
||||
addInputArgument(shape, rate);
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate) {
|
||||
this(shape, rate, 0,0);
|
||||
}
|
||||
|
||||
public RandomPoisson(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable rate, int... seeds) {
|
||||
super(null, sameDiff, new SDVariable[]{shape, rate});
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "random_poisson";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String[] tensorflowNames() {
|
||||
return new String[]{"RandomPoisson","RandomPoissonV2"};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||
if(attributesForNode.containsKey("dtype")) {
|
||||
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("dtype").getType());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s",
|
||||
getClass(), inputDataTypes.size());
|
||||
return Collections.singletonList(outputDataType);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
|
||||
/* ******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.random.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class RandomShuffle extends DynamicCustomOp {
|
||||
|
||||
public RandomShuffle(@NonNull INDArray value, int... seeds) {
|
||||
addInputArgument(value);
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
public RandomShuffle(@NonNull INDArray value) {
|
||||
this(value, 0, 0);
|
||||
}
|
||||
|
||||
public RandomShuffle(@NonNull SameDiff sameDiff, @NonNull SDVariable value, int...seeds) {
|
||||
super(null, sameDiff, new SDVariable[]{value});
|
||||
addIArgument(seeds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "random_shuffle";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "RandomShuffle";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -63,8 +63,8 @@ import static org.junit.Assume.assumeFalse;
|
|||
TransformOpValidation.class,
|
||||
|
||||
//TF import tests
|
||||
TFGraphTestAllSameDiff.class,
|
||||
TFGraphTestAllLibnd4j.class
|
||||
TFGraphTestAllSameDiff.class
|
||||
//TFGraphTestAllLibnd4j.class
|
||||
})
|
||||
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
|
||||
// With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ
|
||||
|
|
|
@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import lombok.val;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
|
@ -1465,6 +1466,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
|
||||
@Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
|
||||
@Test
|
||||
public void testPad(){
|
||||
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.BeforeClass;
|
||||
import org.nd4j.autodiff.execution.NativeGraphExecutioner;
|
||||
|
||||
import org.nd4j.autodiff.execution.conf.ExecutionMode;
|
||||
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||
|
@ -228,9 +229,9 @@ public class TFGraphTestAllHelper {
|
|||
String s1 = s.format(tfPred, false);
|
||||
String s2 = s.format(nd4jPred, false);
|
||||
System.out.print("TF: ");
|
||||
System.out.println(s1);
|
||||
System.out.println(tfPred.toStringFull());
|
||||
System.out.print("SD: ");
|
||||
System.out.println(s2);
|
||||
System.out.println(nd4jPred.toStringFull());
|
||||
}
|
||||
}
|
||||
assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq);
|
||||
|
|
|
@ -111,17 +111,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
|
||||
"zeros_like/rank2_float32_dtype_int.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
|
||||
"betainc.*",
|
||||
|
||||
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
|
||||
"roll/.*",
|
||||
|
||||
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
|
||||
"matrix_band_part/.*",
|
||||
|
||||
// 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507
|
||||
"resize_bicubic/int32.*"
|
||||
"matrix_band_part/.*"
|
||||
};
|
||||
|
||||
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.junit.After;
|
|||
import org.junit.Before;
|
||||
import org.junit.Rule;
|
||||
import org.junit.rules.TestName;
|
||||
import org.junit.rules.Timeout;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Parameterized;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
|
@ -51,6 +52,9 @@ public abstract class BaseNd4jTest {
|
|||
@Rule
|
||||
public TestName testName = new TestName();
|
||||
|
||||
@Rule
|
||||
public Timeout timeout = Timeout.seconds(30);
|
||||
|
||||
protected long startTime;
|
||||
protected int threadCountBefore;
|
||||
|
||||
|
|
|
@ -34,8 +34,6 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where;
|
|||
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
|
||||
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
|
||||
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax;
|
||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.shape.Create;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
|
||||
|
@ -1353,4 +1351,80 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
|
||||
assertEquals(exp, result);
|
||||
}
|
||||
|
||||
// Exact copy of libnd4j test
|
||||
@Test
|
||||
public void testRgbToHsv() {
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
|
||||
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
|
||||
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
|
||||
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
|
||||
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
|
||||
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
|
||||
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
|
||||
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
|
||||
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
|
||||
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
|
||||
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
|
||||
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
|
||||
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
|
||||
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
|
||||
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3);
|
||||
INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
|
||||
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
|
||||
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
|
||||
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
|
||||
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3);
|
||||
RgbToHsv op = new RgbToHsv(input);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
assertEquals(ret[0], expected);
|
||||
}
|
||||
|
||||
// Exact copy of libnd4j test
|
||||
@Test
|
||||
public void testHsvToRgb() {
|
||||
INDArray input = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
|
||||
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
|
||||
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
|
||||
77.6f, 0.49019608f, 0.6f, 260.74468085f,
|
||||
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
|
||||
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
|
||||
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
|
||||
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
|
||||
|
||||
HsvToRgb op = new HsvToRgb(input);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
assertEquals(ret[0], expected);
|
||||
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testHsvToRgb_1() {
|
||||
/* Emulation of simple TF test:
|
||||
image = tf.random_uniform(shape = [1,1,3])
|
||||
tf.image.hsv_to_rgb(image)*/
|
||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}).
|
||||
reshape(1,1,3);
|
||||
HsvToRgb op = new HsvToRgb(image);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3);
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Ignore
|
||||
@Test
|
||||
public void testRgbToHsv_1() {
|
||||
/* Emulation of simple TF test:
|
||||
image = tf.random_uniform(shape = [1,2,3])
|
||||
tf.image.rgb_to_hsv(image)*/
|
||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,
|
||||
0.2309f,0.7271f,0.1804f}).reshape(1,2,3);
|
||||
RgbToHsv op = new RgbToHsv(image);
|
||||
INDArray[] ret = Nd4j.exec(op);
|
||||
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f,
|
||||
0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3);
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,8 +31,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
|||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.custom.*;
|
||||
import org.nd4j.linalg.api.ops.random.impl.*;
|
||||
import org.nd4j.linalg.api.rng.DefaultRandom;
|
||||
import org.nd4j.linalg.api.rng.Random;
|
||||
|
@ -1473,6 +1472,44 @@ public class RandomTests extends BaseNd4jTest {
|
|||
assertEquals(out1, out2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGamma(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray shape = Nd4j.createFromArray(new int[] {1,3});
|
||||
INDArray alpha = Nd4j.rand(1,3);
|
||||
val randomGamma = new RandomGamma(shape, alpha, null);
|
||||
INDArray[] res = Nd4j.exec(randomGamma);
|
||||
|
||||
val randomGamma1 = new RandomGamma(shape, alpha, null);
|
||||
INDArray[] res1 = Nd4j.exec(randomGamma1);
|
||||
assertEquals(res[0], res1[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPoisson(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray shape = Nd4j.createFromArray(new int[] {1,3});
|
||||
INDArray alpha = Nd4j.rand(1,3);
|
||||
val randomPoisson = new RandomPoisson(shape, alpha);
|
||||
INDArray[] res = Nd4j.exec(randomPoisson);
|
||||
|
||||
val randomPoisson1 = new RandomPoisson(shape, alpha);
|
||||
INDArray[] res1 = Nd4j.exec(randomPoisson1);
|
||||
assertEquals(res[0], res1[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testShuffle(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray alpha = Nd4j.rand(1,3);
|
||||
val randomShuffle = new RandomShuffle(alpha);
|
||||
INDArray[] res = Nd4j.exec(randomShuffle);
|
||||
|
||||
val randomShuffle1 = new RandomShuffle(alpha);
|
||||
INDArray[] res1 = Nd4j.exec(randomShuffle1);
|
||||
assertEquals(res[0], res1[0]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
|
|
Loading…
Reference in New Issue