diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java index 786c7ea93..0da356677 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/BaseDL4JTest.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -42,6 +43,8 @@ public class BaseDL4JTest { @Rule public TestName name = new TestName(); + @Rule + public Timeout timeout = Timeout.seconds(30); protected long startTime; protected int threadCountBefore; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 3bd1bd37f..0772072b5 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.datasets.datavec; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -70,6 +71,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; @Slf4j public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { + @Rule + protected Timeout timeout = Timeout.seconds(300); + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 1e82a4783..cb534cc23 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -17,6 +17,7 @@ package org.deeplearning4j.datasets.datavec; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.io.FileUtils; @@ -69,6 +70,9 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { @Rule public TemporaryFolder temporaryFolder = new TemporaryFolder(); + @Rule + protected Timeout timeout = Timeout.seconds(300); + @Test public void testsBasic() throws Exception { //Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index b2b32c715..7bad73f06 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -22,7 +22,9 @@ import org.datavec.api.split.FileSplit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.util.TestDataSetConsumer; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; @@ -37,6 +39,9 @@ import static org.junit.Assert.*; public class MultipleEpochsIteratorTest extends BaseDL4JTest { + @Rule + protected Timeout timeout = Timeout.seconds(300); + @Test public void testNextAndReset() throws Exception { int epochs = 3; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 0a9cfea4c..90d9a37c9 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -161,7 +161,7 @@ public class EvalTest extends BaseDL4JTest { assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix()); } - @Test + @Test(timeout = 300000) public void testEvaluationWithMetaData() throws Exception { RecordReader csv = new CSVRecordReader(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 3e330d248..0aecb4c94 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -317,7 +317,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(paramsMLN, paramsGraph); } - @Test + @Test(timeout = 300000) public void testIrisFitMultiDataSetIterator() throws Exception { RecordReader rr = new CSVRecordReader(0, ','); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java index f65783bce..1c64c1fd1 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/plot/BarnesHutTsneTest.java @@ -103,7 +103,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest { assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps); } - @Test + @Test(timeout = 300000) public void testTsne() throws Exception { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 4dcbce538..867e96f09 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -28,7 +28,9 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.nn.weights.WeightInitRelu; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.util.ModelSerializer; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; @@ -42,6 +44,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; import org.nd4j.resources.Resources; import java.io.File; +import java.sql.Time; import static org.junit.Assert.*; @@ -55,6 +58,9 @@ import static org.junit.Assert.*; */ public class RegressionTest050 extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index fda013533..5973bae71 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; @@ -54,6 +55,9 @@ public class ModelGuesserTest extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void testModelGuessFile() throws Exception { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index 6043d7d48..2fb99dd2e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -33,6 +33,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.impl.ActivationHardSigmoid; import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.api.ndarray.INDArray; @@ -60,6 +61,9 @@ public class FullModelComparisons extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + @Test public void lstmTest() throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException, InterruptedException { diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java index 577e089f9..632bcc692 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/sequence/TimeSeriesGeneratorImportTest.java @@ -30,7 +30,7 @@ import java.io.IOException; */ public class TimeSeriesGeneratorImportTest extends BaseDL4JTest { - @Test + @Test(timeout=300000) public void importTimeSeriesTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/timeseries_generator.json"; diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java index 45114685b..f79ef60a5 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/preprocessing/text/TokenizerImportTest.java @@ -35,7 +35,7 @@ public class TokenizerImportTest extends BaseDL4JTest { ClassLoader classLoader = getClass().getClassLoader(); - @Test + @Test(timeout=300000) public void importTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer.json"; @@ -51,7 +51,7 @@ public class TokenizerImportTest extends BaseDL4JTest { } - @Test + @Test(timeout=300000) public void importNumWordsNullTest() throws IOException, InvalidKerasConfigurationException { String path = "modelimport/keras/preprocessing/tokenizer_num_words_null.json"; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index 69eae7307..4bb420ef6 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -16,6 +16,7 @@ package org.deeplearning4j.models; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; @@ -75,6 +76,9 @@ public class WordVectorSerializerTest extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); + private File textFile, binaryFile, textFile2; private File fastTextRaw, fastTextZip, fastTextGzip; String pathToWriteto; diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java index 1546ead8f..8f13ae3fd 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/embeddings/loader/VectorsConfigurationTest.java @@ -62,7 +62,7 @@ public class VectorsConfigurationTest extends BaseDL4JTest { assertEquals(configuration, configuration2); } - @Test + @Test(timeout = 300000) public void testFromW2V() throws Exception { VectorsConfiguration configuration = new VectorsConfiguration(); configuration.setHugeModelExpected(true); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 736998484..851178666 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -16,6 +16,8 @@ package org.deeplearning4j.models.word2vec; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.shade.guava.primitives.Doubles; import org.nd4j.shade.guava.primitives.Ints; import lombok.val; @@ -68,6 +70,9 @@ public class Word2VecTests extends BaseDL4JTest { private String pathToWriteto; private WordVectors googleModel; + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void before() throws Exception { File googleModelTextFile = new ClassPathResource("word2vecserialization/google_news_30.txt").getFile(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java index d7a0b7934..97a8821c6 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java @@ -52,7 +52,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testConsumeOnEqualVocabs() throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); @@ -99,7 +99,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testConsumeOnNonEqualVocabs() throws Exception { TokenizerFactory t = new DefaultTokenizerFactory(); t.setTokenPreProcessor(new CommonPreprocessor()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index 731a9cd60..8d53e8c1a 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -12,6 +12,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.primitives.Pair; import org.nd4j.resources.Resources; @@ -27,6 +28,9 @@ import static org.junit.Assert.assertEquals; @Slf4j public class FastTextTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java index c8674e630..c031d99ea 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java @@ -55,7 +55,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void hasNext() throws Exception { SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); @@ -77,7 +77,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { assertEquals(97162, cnt); } - @Test + @Test(timeout = 300000) public void testSpeedComparison1() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 25); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index fc10a592a..ae2bf83c7 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -82,7 +82,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { assertEquals(neighbours, nearestWords.size()); } - @Test + @Test(timeout = 300000) public void testUnkSerialization_1() throws Exception { val inputFile = Resources.asFile("big/raw_sentences.txt"); @@ -142,7 +142,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testW2VEmbeddingLayerInit() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index 4a77d6807..b84ecf95c 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -50,7 +50,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { /** * Basically all we want from this test - being able to finish without exceptions. */ - @Test + @Test(timeout = 300000) public void testIterator1() throws Exception { File inputFile = Resources.asFile("big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java index 84b8d3c38..430c21492 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java @@ -22,6 +22,7 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.junit.Rule; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator; import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator; @@ -43,6 +44,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.sql.Time; import java.util.*; import java.util.concurrent.atomic.AtomicBoolean; @@ -53,6 +55,9 @@ import static org.junit.Assert.*; */ public class VocabConstructorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(VocabConstructorTest.class); TokenizerFactory t = new DefaultTokenizerFactory(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java index 7c8ea3ba7..dbf6c7aa4 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/AsyncLabelAwareIteratorTest.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals; * @author raver119@gmail.com */ public class AsyncLabelAwareIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void nextDocument() throws Exception { SentenceIterator sentence = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")); BasicLabelAwareIterator backed = new BasicLabelAwareIterator.Builder(sentence).build(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java index 1696226d3..984a098cc 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/documentiterator/BasicLabelAwareIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.documentiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.text.sentenceiterator.BasicLineIterator; import org.deeplearning4j.text.sentenceiterator.SentenceIterator; @@ -33,6 +35,9 @@ import static org.junit.Assert.assertEquals; */ public class BasicLabelAwareIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java index cd6fe3449..775e18d3f 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIteratorTest.java @@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals; */ public class AggregatingSentenceIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void testHasNext() throws Exception { File file = Resources.asFile("/big/raw_sentences.txt"); BasicLineIterator iterator = new BasicLineIterator(file); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java index e6aaa338e..0f01937ed 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/BasicLineIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.junit.Before; import org.junit.Test; @@ -32,6 +34,9 @@ import static org.junit.Assert.assertEquals; */ public class BasicLineIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java index b187caf8d..1a3c215aa 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIteratorTest.java @@ -27,7 +27,7 @@ import static org.junit.Assert.assertEquals; * @author raver119@gmail.com */ public class MutipleEpochsSentenceIteratorTest extends BaseDL4JTest { - @Test + @Test(timeout = 300000) public void hasNext() throws Exception { SentenceIterator iterator = new MutipleEpochsSentenceIterator( new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 100); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java index 414f1454b..f0f6f1c54 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIteratorTest.java @@ -17,6 +17,8 @@ package org.deeplearning4j.text.sentenceiterator; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; +import org.junit.rules.Timeout; import org.nd4j.linalg.io.ClassPathResource; import org.junit.Test; import org.nd4j.resources.Resources; @@ -33,6 +35,9 @@ import static org.junit.Assert.assertTrue; */ public class PrefetchingSentenceIteratorTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(300); + protected static final Logger log = LoggerFactory.getLogger(PrefetchingSentenceIteratorTest.class); @Test diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java index 4b78e51a2..80570ae54 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java @@ -205,7 +205,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest { } - @Test + @Test(timeout = 300000) public void testBertWordPieceTokenizer10() throws Exception { File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt"); BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 4cec0eed4..e4d48937e 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -27,7 +27,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.Ignore; +import org.junit.*; +import org.junit.rules.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; @@ -38,9 +39,6 @@ import org.deeplearning4j.parallelism.inference.InferenceObservable; import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver; import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable; import org.deeplearning4j.util.ModelSerializer; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -66,6 +64,9 @@ public class ParallelInferenceTest extends BaseDL4JTest { private static MultiLayerNetwork model; private static DataSetIterator iterator; + @Rule + public Timeout timeout = Timeout.seconds(300); + @Before public void setUp() throws Exception { if (model == null) { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java index 26a1ec950..7621faa3b 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/test/java/org/deeplearning4j/ui/ManualTests.java @@ -100,7 +100,7 @@ public class ManualTests { } - @Test + @Test(timeout = 300000) public void testTsne() throws Exception { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); Nd4j.getRandom().setSeed(123); @@ -208,7 +208,7 @@ public class ManualTests { } - @Test + @Test(timeout = 300000) public void testWord2VecPlot() throws Exception { File inputFile = Resources.asFile("big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); diff --git a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp index f2215503c..84facc0cc 100644 --- a/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/fusedBatchNorm.cpp @@ -32,7 +32,7 @@ namespace ops { ->setAllowedOutputTypes({ALL_FLOATS}); } -CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { +CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) { auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW) auto scale = INPUT_VARIABLE(1); // [iD] auto offset = INPUT_VARIABLE(2); // [iD] @@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) { REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str()); } else { - REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); + //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width()); std::vector shape = {iD}; mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp index b17167b9a..ced05ceaa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/image_resize.cpp @@ -682,27 +682,29 @@ namespace helpers { pY2[pt_index], pY3[pt_index]); } - template + template static void bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) { std::vector xWais(resizerState.outWidth); - computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais); const auto numChannels = resizerState.channels; const Nd4jLong inRowWidth = resizerState.inWidth * numChannels; const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth; + const auto batchNum = resizerState.batchSize; + const auto outHeight = resizerState.outHeight; + const auto outWidth = resizerState.outWidth; - const T* inputPtr = image->getDataBuffer()->primaryAsT(); - float* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway - std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); + auto func = PRAGMA_THREADS_FOR { + const T* inputPtr = image->getDataBuffer()->primaryAsT(); + F* pOutputY = output->dataBuffer()->primaryAsT(); // output is float anyway + std::vector cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0); - auto func = PRAGMA_THREADS_FOR { for (auto b = start; b < stop; ++b) { auto pInput = inputPtr + b * inBatchWidth; - for (auto y = 0; y < resizerState.outHeight; ++y) { - auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels]; + for (auto y = 0; y < outHeight; ++y) { + auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels]; WeightsAndIndices yWai; if (halfPixelCenters) { @@ -713,16 +715,16 @@ namespace helpers { resizerState.heightScale, y, resizerState.inHeight, &yWai); } // Make pointers represent offsets of data in inputBPtr. - const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth; - const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth; - const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth; - const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth; + const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth; + const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth; + const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth; + const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth; if (numChannels == 3) { // Manually unroll case of 3 channels. - float cached_value_0[4] = {0}; - float cached_value_1[4] = {0}; - float cached_value_2[4] = {0}; + F cached_value_0[4] = {0}; + F cached_value_1[4] = {0}; + F cached_value_2[4] = {0}; for (auto x = 0; x < resizerState.outWidth; ++x) { const WeightsAndIndices &xWai = xWais[x]; // Shift values in cached_value_* to fill first '_advance' values. @@ -854,7 +856,7 @@ namespace helpers { } for (auto c = 0; c < numChannels; ++c) { pOutput[x * numChannels + c] = - compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, + (F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1, xWai._weight2, xWai._weight3); } } @@ -862,7 +864,7 @@ namespace helpers { } } }; - samediff::Threads::parallel_tad(func, 0, resizerState.batchSize); + samediff::Threads::parallel_tad(func, 0, batchNum); } // simplified bicubic resize without antialiasing @@ -873,7 +875,7 @@ namespace helpers { ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align int res = st.validateAndCreateOutput(image, width, height); if (res == Status::OK()) - bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); + bicubicInterpolateWithCaching(image, st, halfPixelAlign, output); return res; } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index 647f37271..7b08bfbe4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -975,6 +975,118 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { delete results; } +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511, + 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613, + 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989, + 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312, + 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); + + NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 1}, { + 0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, + 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, + 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, + 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, + 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, + 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, + 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, + 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, + 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, + 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, + 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, + 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, + 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, + 0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f, + 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, + 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, + 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, + 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, + 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, + 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, + 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, + 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, + 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, + 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, + 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, + 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, + 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, + 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, + 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, + 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, + 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, + 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, + 0.9139262f, 0.92068815f + }); + auto size = NDArrayFactory::create({9, 9}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 9x9"); +// expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + delete results; +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338, + 0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859, + 0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232, + 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986, + 0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547, + 0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666, + 0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496, + 0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752, + 0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, + 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859}); + + + auto testData = NDArrayFactory::create('c', {2,9,9,1}, { + 0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, + 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, + 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, + 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, + 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, + 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, + 0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f, + + 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, + 0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f, + 0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f, + 0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f, + 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, + 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, + + 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, + 0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f, + 0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f, + 0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f, + 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f + }); + + auto size = NDArrayFactory::create({9, 9}); + nd4j::ops::resize_bicubic op; + auto results = op.execute({&input, &size}, {}, {}, {true, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + NDArray* result = results->at(0); + +// result->printBuffer("Resized to 9x9"); +// expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(testData.isSameShape(result)); + ASSERT_TRUE(testData.equalsTo(result)); + delete results; +} + /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index b0fc00bac..0a725786e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -232,10 +232,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; -import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; -import org.nd4j.linalg.api.ops.random.custom.RandomExponential; -import org.nd4j.linalg.api.ops.random.custom.RandomNormal; +import org.nd4j.linalg.api.ops.random.custom.*; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; @@ -384,6 +381,18 @@ public class DifferentialFunctionFactory { return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); } + public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) { + return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable(); + } + + public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) { + return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable(); + } + + public SDVariable randomShuffle(SDVariable values, int... seeds) { + return new RandomShuffle(sameDiff(), values, seeds).outputVariable(); + } + /** * Exponential distribution: P(x) = lambda * exp(-lambda * x) * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index bf71a665e..4cc020e3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -3,10 +3,7 @@ package org.nd4j.autodiff.samediff.ops; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ops.custom.AdjustContrast; -import org.nd4j.linalg.api.ops.custom.AdjustHue; -import org.nd4j.linalg.api.ops.custom.AdjustSaturation; -import org.nd4j.linalg.api.ops.custom.RandomCrop; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; @@ -119,4 +116,26 @@ public class SDImage extends SDOps { SDVariable out = new RandomCrop(sd, input, shape).outputVariable(); return updateVariableNameAndReference(out, name); } + + /** + * Converting array from HSV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToHsv(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToHsv(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from HSV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable hsvToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new HsvToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index 66e52c151..879c08cba 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger; @@ -295,4 +296,43 @@ public class SDRandom extends SDOps { return updateVariableNameAndReference(ret, name); } + /** + * Generate a new random SDVariable with Gamma distribution + * + * @param name Name of the output variable + * @param alpha distribution parameter + * @param beta distribution parameter + * @param shape Shape of the new variable + * @return new SDVariable + */ + public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) { + SDVariable ret = f().randomGamma(alpha, beta, shape); + return updateVariableNameAndReference(ret, name); + } + + /** + * Generate a new random SDVariable with Poission distribution + * + * @param name Name of the output variable + * @param lambda rate distribution parameter + * @param shape Shape of the new variable + * @return new SDVariable + */ + public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) { + SDVariable ret = f().randomPoisson(shape, lambda, seeds); + return updateVariableNameAndReference(ret, name); + } + + /** + * Generate a new random SDVariable by random shuffle + * + * @param name Name of the output variable + * @param value array to shuffle + * @return new SDVariable + */ + public SDVariable shuffle(String name, SDVariable value, int... seeds) { + SDVariable ret = f().randomShuffle(value, seeds); + return updateVariableNameAndReference(ret, name); + } + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index cb63dab61..700d48a3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -572,6 +572,9 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.random.custom.RandomBernoulli.class, org.nd4j.linalg.api.ops.random.custom.RandomExponential.class, org.nd4j.linalg.api.ops.random.custom.RandomNormal.class, + org.nd4j.linalg.api.ops.random.custom.RandomGamma.class, + org.nd4j.linalg.api.ops.random.custom.RandomPoisson.class, + org.nd4j.linalg.api.ops.random.custom.RandomShuffle.class, org.nd4j.linalg.api.ops.random.impl.AlphaDropOut.class, org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution.class, org.nd4j.linalg.api.ops.random.impl.BinomialDistribution.class, @@ -588,6 +591,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class, org.nd4j.linalg.api.ops.custom.AdjustContrast.class, org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, + org.nd4j.linalg.api.ops.custom.HsvToRgb.class, + org.nd4j.linalg.api.ops.custom.RgbToHsv.class, org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class, @@ -601,7 +606,10 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.Polygamma.class, org.nd4j.linalg.api.ops.custom.RandomCrop.class, org.nd4j.linalg.api.ops.custom.Roll.class, - org.nd4j.linalg.api.ops.custom.ToggleBits.class + org.nd4j.linalg.api.ops.custom.ToggleBits.class, + org.nd4j.linalg.api.ops.custom.Igamma.class, + org.nd4j.linalg.api.ops.custom.Igammac.class, + org.nd4j.linalg.api.ops.custom.Digamma.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java new file mode 100644 index 000000000..206d0027f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Digamma.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Digamma extends DynamicCustomOp { + public Digamma(@NonNull INDArray x) { + addInputArgument(x); + } + + public Digamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{x}); + } + + @Override + public String opName() { + return "digamma"; + } + + @Override + public String tensorflowName() { + return "Digamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java index 691e5d43f..3bbab11be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -19,15 +19,23 @@ import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; public class FusedBatchNorm extends DynamicCustomOp { + private DataType outputDataType; + public FusedBatchNorm() {} public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset, @@ -38,6 +46,7 @@ public class FusedBatchNorm extends DynamicCustomOp { if (yOut != null && batchMeanOut != null && batchMeanVar != null) { addOutputArgument(yOut, batchMeanOut, batchMeanVar); } + this.outputDataType = x.dataType(); } public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset, @@ -51,14 +60,25 @@ public class FusedBatchNorm extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "FusedBatchNormV2"; + public String[] tensorflowNames() { + return new String[]{"FusedBatchNormV2","FusedBatchNormV3"}; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + boolean isNchw = attributesForNode.containsKey("data_format") && attributesForNode.get("data_format").getS().toStringUtf8().equalsIgnoreCase("NCHW"); + boolean training = !attributesForNode.containsKey("is_training") ? true : attributesForNode.get("is_training").getB(); + addIArgument(isNchw ? 1 : 0); + addIArgument(training ? 1 : 0); + if(attributesForNode.containsKey("T")){ + outputDataType = TFGraphMapper.convertType(attributesForNode.get("T").getType()); + } } @Override public List calculateOutputDataTypes(List inputDataTypes){ int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Collections.singletonList(inputDataTypes.get(0)); + return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java new file mode 100644 index 000000000..6b1361376 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/HsvToRgb.java @@ -0,0 +1,57 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class HsvToRgb extends DynamicCustomOp { + + public HsvToRgb(INDArray input) { + addInputArgument(input); + } + + public HsvToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "hsv_to_rgb"; + } + + @Override + public String tensorflowName() { + return "HSVToRGB"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java new file mode 100644 index 000000000..e8efddb12 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igamma.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Igamma extends DynamicCustomOp { + public Igamma(@NonNull INDArray n, @NonNull INDArray x) { + Preconditions.checkArgument(n.shape() != x.shape(), + "Igamma: n and x must have the same shapes"); + addInputArgument(n,x); + } + + public Igamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) { + this(n,x); + if (output != null) { + addOutputArgument(output); + } + } + + public Igamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{n ,x}); + } + + @Override + public String opName() { + return "igamma"; + } + + @Override + public String tensorflowName() { + return "Igamma"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java new file mode 100644 index 000000000..915a57764 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Igammac.java @@ -0,0 +1,66 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class Igammac extends DynamicCustomOp { + public Igammac(@NonNull INDArray n, @NonNull INDArray x) { + Preconditions.checkArgument(n.shape() != x.shape(), + "Igamma: n and x must have the same shapes"); + addInputArgument(n,x); + } + + public Igammac(@NonNull INDArray n, @NonNull INDArray x, INDArray output) { + this(n,x); + if (output != null) { + addOutputArgument(output); + } + } + + public Igammac(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) { + super("", sameDiff, new SDVariable[]{n ,x}); + } + + @Override + public String opName() { + return "igammac"; + } + + @Override + public String tensorflowName() { + return "Igammac"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java new file mode 100644 index 000000000..d96981460 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToHsv.java @@ -0,0 +1,57 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToHsv extends DynamicCustomOp { + + public RgbToHsv(INDArray input) { + addInputArgument(input); + } + + public RgbToHsv(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_hsv"; + } + + @Override + public String tensorflowName() { + return "RGBToHSV"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 6d6798701..fb4dc1c80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index f51b94218..78774d3da 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 1b885676e..cc97c3ddb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -61,7 +61,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index b2e254fb7..4f18b4cec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -61,7 +61,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index ef34e9f81..e995ec427 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -60,7 +60,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); List out = new ArrayList<>(); for( int i=0; i calculateOutputDataTypes(List inputDataTypes){ - Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); //TODO Allow customizing output type return Collections.singletonList(Nd4j.defaultFloatingPointType()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java new file mode 100644 index 000000000..d7ad376f0 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomGamma.java @@ -0,0 +1,84 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class RandomGamma extends DynamicCustomOp { + + public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta, + int... seeds) { + if (beta != null) { + addInputArgument(shape,alpha,beta); + } + addInputArgument(shape,alpha); + addIArgument(seeds); + } + + public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta) { + + this(shape,alpha,beta,0,0); + } + + public RandomGamma(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, + @NonNull SDVariable alpha, SDVariable beta, int... seeds) { + super(null, sameDiff, beta != null ? new SDVariable[]{shape, alpha, beta} : + new SDVariable[]{shape, alpha}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_gamma"; + } + + @Override + public String tensorflowName() { + return "RandomGamma"; + } + + private DataType outputDataType = DataType.FLOAT; + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("alpha")) { + outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("alpha").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null, "Expected exactly input datatypes for %s, got null", getClass()); + return Collections.singletonList(outputDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java new file mode 100644 index 000000000..b407ca47d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java @@ -0,0 +1,79 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.rng.Random; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class RandomPoisson extends DynamicCustomOp { + + private DataType outputDataType = DataType.FLOAT; + + public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate, int... seeds) { + addInputArgument(shape, rate); + addIArgument(seeds); + } + + public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate) { + this(shape, rate, 0,0); + } + + public RandomPoisson(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable rate, int... seeds) { + super(null, sameDiff, new SDVariable[]{shape, rate}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_poisson"; + } + + @Override + public String[] tensorflowNames() { + return new String[]{"RandomPoisson","RandomPoissonV2"}; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("dtype")) { + outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("dtype").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", + getClass(), inputDataTypes.size()); + return Collections.singletonList(outputDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java new file mode 100644 index 000000000..b08970f11 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomShuffle.java @@ -0,0 +1,63 @@ + +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.random.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RandomShuffle extends DynamicCustomOp { + + public RandomShuffle(@NonNull INDArray value, int... seeds) { + addInputArgument(value); + addIArgument(seeds); + } + + public RandomShuffle(@NonNull INDArray value) { + this(value, 0, 0); + } + + public RandomShuffle(@NonNull SameDiff sameDiff, @NonNull SDVariable value, int...seeds) { + super(null, sameDiff, new SDVariable[]{value}); + addIArgument(seeds); + } + + @Override + public String opName() { + return "random_shuffle"; + } + + @Override + public String tensorflowName() { + return "RandomShuffle"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java index b4bb6e1a4..1acc013c2 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/OpValidationSuite.java @@ -63,8 +63,8 @@ import static org.junit.Assume.assumeFalse; TransformOpValidation.class, //TF import tests - TFGraphTestAllSameDiff.class, - TFGraphTestAllLibnd4j.class + TFGraphTestAllSameDiff.class + //TFGraphTestAllLibnd4j.class }) //IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test" // With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 6a42d21e1..415e76fd5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.After; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.OpValidationSuite; import org.nd4j.autodiff.functions.DifferentialFunction; @@ -1465,6 +1466,7 @@ public class TransformOpValidation extends BaseOpValidation { } + @Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Test public void testPad(){ INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index eae14b230..afcd8c3e1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -26,6 +26,7 @@ import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.nd4j.autodiff.execution.NativeGraphExecutioner; + import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; @@ -228,9 +229,9 @@ public class TFGraphTestAllHelper { String s1 = s.format(tfPred, false); String s2 = s.format(nd4jPred, false); System.out.print("TF: "); - System.out.println(s1); + System.out.println(tfPred.toStringFull()); System.out.print("SD: "); - System.out.println(s2); + System.out.println(nd4jPred.toStringFull()); } } assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 9e3db5b1a..d3e48b4d3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -111,17 +111,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398 "zeros_like/rank2_float32_dtype_int.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450 - "betainc.*", - // 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453 "roll/.*", // 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455 - "matrix_band_part/.*", - - // 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507 - "resize_bicubic/int32.*" + "matrix_band_part/.*" }; /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java index c3c94e1ed..3f75a3293 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/BaseNd4jTest.java @@ -23,6 +23,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.config.ND4JSystemProperties; @@ -51,6 +52,9 @@ public abstract class BaseNd4jTest { @Rule public TestName testName = new TestName(); + @Rule + public Timeout timeout = Timeout.seconds(30); + protected long startTime; protected int threadCountBefore; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index c3d4fe699..b03ab1156 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -34,8 +34,6 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; -import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; @@ -1353,4 +1351,80 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(exp, result); } + + // Exact copy of libnd4j test + @Test + public void testRgbToHsv() { + INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, + 3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, + 9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, + 2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, + 4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, + 9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, + 3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, + 2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, + 6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, + 2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, + 7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, + 9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, + 1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, + 9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, + 9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3); + INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, + 50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, + 59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, + 156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, + 115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3); + RgbToHsv op = new RgbToHsv(input); + INDArray[] ret = Nd4j.exec(op); + assertEquals(ret[0], expected); + } + + // Exact copy of libnd4j test + @Test + public void testHsvToRgb() { + INDArray input = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, + 0.9047619f, 0.65882353f, 71.30044843f, 1.f, + 0.8745098f, 180.f, 0.74871795f, 0.76470588f, + 77.6f, 0.49019608f, 0.6f, 260.74468085f, + 0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f, + 0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3); + + INDArray expected = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, + 153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3); + + HsvToRgb op = new HsvToRgb(input); + INDArray[] ret = Nd4j.exec(op); + assertEquals(ret[0], expected); + + } + + @Ignore + @Test + public void testHsvToRgb_1() { + /* Emulation of simple TF test: + image = tf.random_uniform(shape = [1,1,3]) + tf.image.hsv_to_rgb(image)*/ + INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}). + reshape(1,1,3); + HsvToRgb op = new HsvToRgb(image); + INDArray[] ret = Nd4j.exec(op); + INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3); + assertEquals(expected, ret[0]); + } + + @Ignore + @Test + public void testRgbToHsv_1() { + /* Emulation of simple TF test: + image = tf.random_uniform(shape = [1,2,3]) + tf.image.rgb_to_hsv(image)*/ + INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f, + 0.2309f,0.7271f,0.1804f}).reshape(1,2,3); + RgbToHsv op = new RgbToHsv(image); + INDArray[] ret = Nd4j.exec(op); + INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f, + 0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3); + assertEquals(expected, ret[0]); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index ed8f4d441..c663f3db8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -31,8 +31,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.util.DataTypeUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.ops.random.custom.DistributionUniform; -import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; +import org.nd4j.linalg.api.ops.random.custom.*; import org.nd4j.linalg.api.ops.random.impl.*; import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.Random; @@ -1473,6 +1472,44 @@ public class RandomTests extends BaseNd4jTest { assertEquals(out1, out2); } + @Test + public void testGamma(){ + Nd4j.getRandom().setSeed(12345); + INDArray shape = Nd4j.createFromArray(new int[] {1,3}); + INDArray alpha = Nd4j.rand(1,3); + val randomGamma = new RandomGamma(shape, alpha, null); + INDArray[] res = Nd4j.exec(randomGamma); + + val randomGamma1 = new RandomGamma(shape, alpha, null); + INDArray[] res1 = Nd4j.exec(randomGamma1); + assertEquals(res[0], res1[0]); + } + + @Test + public void testPoisson(){ + Nd4j.getRandom().setSeed(12345); + INDArray shape = Nd4j.createFromArray(new int[] {1,3}); + INDArray alpha = Nd4j.rand(1,3); + val randomPoisson = new RandomPoisson(shape, alpha); + INDArray[] res = Nd4j.exec(randomPoisson); + + val randomPoisson1 = new RandomPoisson(shape, alpha); + INDArray[] res1 = Nd4j.exec(randomPoisson1); + assertEquals(res[0], res1[0]); + } + + @Test + public void testShuffle(){ + Nd4j.getRandom().setSeed(12345); + INDArray alpha = Nd4j.rand(1,3); + val randomShuffle = new RandomShuffle(alpha); + INDArray[] res = Nd4j.exec(randomShuffle); + + val randomShuffle1 = new RandomShuffle(alpha); + INDArray[] res1 = Nd4j.exec(randomShuffle1); + assertEquals(res[0], res1[0]); + } + @Override public char ordering() { return 'c';