Added missing Java ops wrappers (#122)

* Timeouts added

* Added some ops

* Ops added

* Fixed tests

* Minor fix

* Some fixes

* Digamma added

* Small fixes

* Timeouts added

* Added some ops

* Ops added

* Fixed tests

* Minor fix

* Some fixes

* Digamma added

* Small fixes

* Fused batch norm fixes-

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tests switched off.

* Added test for resize_bicubic.

* Eliminated wasted in test of bicubic resize.

* Switched off multithreading explicit.

* HsvToRgb and RgbToHsv added

* Eliminated waste comments and conform proper float constants.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed multithreading with resize_bicubic helper for cpu platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* ResizeBicubic was fixed.

* Some fixes

* Fix op name

* Validation fixed.

* Clarifications for tests

* Wrappers and small fixes for new ops.
master
Alexander Stoyakin 2019-12-19 11:15:48 +02:00 committed by Alex Black
parent e0a9cb6c08
commit f5068f3980
59 changed files with 990 additions and 77 deletions

View File

@ -22,6 +22,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -42,6 +43,8 @@ public class BaseDL4JTest {
@Rule
public TestName name = new TestName();
@Rule
public Timeout timeout = Timeout.seconds(30);
protected long startTime;
protected int threadCountBefore;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
@ -70,6 +71,9 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
@Slf4j
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Rule
protected Timeout timeout = Timeout.seconds(300);
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.datasets.datavec;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
@ -69,6 +70,9 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@Rule
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
protected Timeout timeout = Timeout.seconds(300);
@Test
public void testsBasic() throws Exception {
//Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator

View File

@ -22,7 +22,9 @@ import org.datavec.api.split.FileSplit;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@ -37,6 +39,9 @@ import static org.junit.Assert.*;
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule
protected Timeout timeout = Timeout.seconds(300);
@Test
public void testNextAndReset() throws Exception {
int epochs = 3;

View File

@ -161,7 +161,7 @@ public class EvalTest extends BaseDL4JTest {
assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
}
@Test
@Test(timeout = 300000)
public void testEvaluationWithMetaData() throws Exception {
RecordReader csv = new CSVRecordReader();

View File

@ -317,7 +317,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
assertEquals(paramsMLN, paramsGraph);
}
@Test
@Test(timeout = 300000)
public void testIrisFitMultiDataSetIterator() throws Exception {
RecordReader rr = new CSVRecordReader(0, ',');

View File

@ -103,7 +103,7 @@ public class BarnesHutTsneTest extends BaseDL4JTest {
assertArrayEquals(exp.data().asDouble(), b.getData().data().asDouble(), eps);
}
@Test
@Test(timeout = 300000)
public void testTsne() throws Exception {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(123);

View File

@ -28,7 +28,9 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.deeplearning4j.nn.weights.WeightInitRelu;
import org.deeplearning4j.nn.weights.WeightInitXavier;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
@ -42,6 +44,7 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
import org.nd4j.resources.Resources;
import java.io.File;
import java.sql.Time;
import static org.junit.Assert.*;
@ -55,6 +58,9 @@ import static org.junit.Assert.*;
*/
public class RegressionTest050 extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -29,6 +29,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.Normalizer;
@ -54,6 +55,9 @@ public class ModelGuesserTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test
public void testModelGuessFile() throws Exception {

View File

@ -33,6 +33,7 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -60,6 +61,9 @@ public class FullModelComparisons extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@Test
public void lstmTest() throws IOException, UnsupportedKerasConfigurationException,
InvalidKerasConfigurationException, InterruptedException {

View File

@ -30,7 +30,7 @@ import java.io.IOException;
*/
public class TimeSeriesGeneratorImportTest extends BaseDL4JTest {
@Test
@Test(timeout=300000)
public void importTimeSeriesTest() throws IOException, InvalidKerasConfigurationException {
String path = "modelimport/keras/preprocessing/timeseries_generator.json";

View File

@ -35,7 +35,7 @@ public class TokenizerImportTest extends BaseDL4JTest {
ClassLoader classLoader = getClass().getClassLoader();
@Test
@Test(timeout=300000)
public void importTest() throws IOException, InvalidKerasConfigurationException {
String path = "modelimport/keras/preprocessing/tokenizer.json";
@ -51,7 +51,7 @@ public class TokenizerImportTest extends BaseDL4JTest {
}
@Test
@Test(timeout=300000)
public void importNumWordsNullTest() throws IOException, InvalidKerasConfigurationException {
String path = "modelimport/keras/preprocessing/tokenizer_num_words_null.json";

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.models;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.io.Files;
import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
@ -75,6 +76,9 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
private File textFile, binaryFile, textFile2;
private File fastTextRaw, fastTextZip, fastTextGzip;
String pathToWriteto;

View File

@ -62,7 +62,7 @@ public class VectorsConfigurationTest extends BaseDL4JTest {
assertEquals(configuration, configuration2);
}
@Test
@Test(timeout = 300000)
public void testFromW2V() throws Exception {
VectorsConfiguration configuration = new VectorsConfiguration();
configuration.setHugeModelExpected(true);

View File

@ -16,6 +16,8 @@
package org.deeplearning4j.models.word2vec;
import org.junit.Rule;
import org.junit.rules.Timeout;
import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Ints;
import lombok.val;
@ -68,6 +70,9 @@ public class Word2VecTests extends BaseDL4JTest {
private String pathToWriteto;
private WordVectors googleModel;
@Rule
public Timeout timeout = Timeout.seconds(300);
@Before
public void before() throws Exception {
File googleModelTextFile = new ClassPathResource("word2vecserialization/google_news_30.txt").getFile();

View File

@ -52,7 +52,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest {
}
@Test
@Test(timeout = 300000)
public void testConsumeOnEqualVocabs() throws Exception {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
@ -99,7 +99,7 @@ public class InMemoryLookupTableTest extends BaseDL4JTest {
}
@Test
@Test(timeout = 300000)
public void testConsumeOnNonEqualVocabs() throws Exception {
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());

View File

@ -12,6 +12,7 @@ import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.resources.Resources;
@ -27,6 +28,9 @@ import static org.junit.Assert.assertEquals;
@Slf4j
public class FastTextTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt");
private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin");
private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin");

View File

@ -55,7 +55,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest {
}
@Test
@Test(timeout = 300000)
public void hasNext() throws Exception {
SentenceIterator iterator = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt"));
@ -77,7 +77,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest {
assertEquals(97162, cnt);
}
@Test
@Test(timeout = 300000)
public void testSpeedComparison1() throws Exception {
SentenceIterator iterator = new MutipleEpochsSentenceIterator(
new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 25);

View File

@ -82,7 +82,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
assertEquals(neighbours, nearestWords.size());
}
@Test
@Test(timeout = 300000)
public void testUnkSerialization_1() throws Exception {
val inputFile = Resources.asFile("big/raw_sentences.txt");
@ -142,7 +142,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
}
@Test
@Test(timeout = 300000)
public void testW2VEmbeddingLayerInit() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

View File

@ -50,7 +50,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
/**
* Basically all we want from this test - being able to finish without exceptions.
*/
@Test
@Test(timeout = 300000)
public void testIterator1() throws Exception {
File inputFile = Resources.asFile("big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

View File

@ -22,6 +22,7 @@ import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.iterators.AbstractSequenceIterator;
@ -43,6 +44,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.sql.Time;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
@ -53,6 +55,9 @@ import static org.junit.Assert.*;
*/
public class VocabConstructorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
protected static final Logger log = LoggerFactory.getLogger(VocabConstructorTest.class);
TokenizerFactory t = new DefaultTokenizerFactory();

View File

@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
* @author raver119@gmail.com
*/
public class AsyncLabelAwareIteratorTest extends BaseDL4JTest {
@Test
@Test(timeout = 300000)
public void nextDocument() throws Exception {
SentenceIterator sentence = new BasicLineIterator(Resources.asFile("big/raw_sentences.txt"));
BasicLabelAwareIterator backed = new BasicLabelAwareIterator.Builder(sentence).build();

View File

@ -17,6 +17,8 @@
package org.deeplearning4j.text.documentiterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.rules.Timeout;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.text.sentenceiterator.BasicLineIterator;
import org.deeplearning4j.text.sentenceiterator.SentenceIterator;
@ -33,6 +35,9 @@ import static org.junit.Assert.assertEquals;
*/
public class BasicLabelAwareIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Before
public void setUp() throws Exception {

View File

@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
*/
public class AggregatingSentenceIteratorTest extends BaseDL4JTest {
@Test
@Test(timeout = 300000)
public void testHasNext() throws Exception {
File file = Resources.asFile("/big/raw_sentences.txt");
BasicLineIterator iterator = new BasicLineIterator(file);

View File

@ -17,6 +17,8 @@
package org.deeplearning4j.text.sentenceiterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.rules.Timeout;
import org.nd4j.linalg.io.ClassPathResource;
import org.junit.Before;
import org.junit.Test;
@ -32,6 +34,9 @@ import static org.junit.Assert.assertEquals;
*/
public class BasicLineIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
@Before
public void setUp() throws Exception {

View File

@ -27,7 +27,7 @@ import static org.junit.Assert.assertEquals;
* @author raver119@gmail.com
*/
public class MutipleEpochsSentenceIteratorTest extends BaseDL4JTest {
@Test
@Test(timeout = 300000)
public void hasNext() throws Exception {
SentenceIterator iterator = new MutipleEpochsSentenceIterator(
new BasicLineIterator(Resources.asFile("big/raw_sentences.txt")), 100);

View File

@ -17,6 +17,8 @@
package org.deeplearning4j.text.sentenceiterator;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.rules.Timeout;
import org.nd4j.linalg.io.ClassPathResource;
import org.junit.Test;
import org.nd4j.resources.Resources;
@ -33,6 +35,9 @@ import static org.junit.Assert.assertTrue;
*/
public class PrefetchingSentenceIteratorTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(300);
protected static final Logger log = LoggerFactory.getLogger(PrefetchingSentenceIteratorTest.class);
@Test

View File

@ -205,7 +205,7 @@ public class BertWordPieceTokenizerTests extends BaseDL4JTest {
}
@Test
@Test(timeout = 300000)
public void testBertWordPieceTokenizer10() throws Exception {
File f = Resources.asFile("deeplearning4j-nlp/bert/uncased_L-12_H-768_A-12/vocab.txt");
BertWordPieceTokenizerFactory t = new BertWordPieceTokenizerFactory(f, true, true, StandardCharsets.UTF_8);

View File

@ -27,7 +27,8 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.Ignore;
import org.junit.*;
import org.junit.rules.Timeout;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.io.ClassPathResource;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
@ -38,9 +39,6 @@ import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -66,6 +64,9 @@ public class ParallelInferenceTest extends BaseDL4JTest {
private static MultiLayerNetwork model;
private static DataSetIterator iterator;
@Rule
public Timeout timeout = Timeout.seconds(300);
@Before
public void setUp() throws Exception {
if (model == null) {

View File

@ -100,7 +100,7 @@ public class ManualTests {
}
@Test
@Test(timeout = 300000)
public void testTsne() throws Exception {
DataTypeUtil.setDTypeForContext(DataType.DOUBLE);
Nd4j.getRandom().setSeed(123);
@ -208,7 +208,7 @@ public class ManualTests {
}
@Test
@Test(timeout = 300000)
public void testWord2VecPlot() throws Exception {
File inputFile = Resources.asFile("big/raw_sentences.txt");
SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

View File

@ -32,7 +32,7 @@ namespace ops {
->setAllowedOutputTypes({ALL_FLOATS});
}
CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) {
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
auto x = INPUT_VARIABLE(0); // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
auto scale = INPUT_VARIABLE(1); // [iD]
auto offset = INPUT_VARIABLE(2); // [iD]
@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 1, false, 0, 2) {
REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
}
else {
REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
//REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
std::vector<Nd4jLong> shape = {iD};
mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());

View File

@ -682,27 +682,29 @@ namespace helpers {
pY2[pt_index], pY3[pt_index]);
}
template <typename T>
template <typename T, typename F>
static void
bicubicInterpolateWithCaching(NDArray const* image, ImageResizerState const& resizerState, bool const halfPixelCenters, NDArray* output) {
std::vector<WeightsAndIndices> xWais(resizerState.outWidth);
computeXWeightsAndIndices(resizerState, halfPixelCenters, &xWais);
const auto numChannels = resizerState.channels;
const Nd4jLong inRowWidth = resizerState.inWidth * numChannels;
const Nd4jLong inBatchWidth = resizerState.inHeight * inRowWidth;
const auto batchNum = resizerState.batchSize;
const auto outHeight = resizerState.outHeight;
const auto outWidth = resizerState.outWidth;
const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
float* pOutputY = output->dataBuffer()->primaryAsT<float>(); // output is float anyway
std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);
auto func = PRAGMA_THREADS_FOR {
const T* inputPtr = image->getDataBuffer()->primaryAsT<T>();
F* pOutputY = output->dataBuffer()->primaryAsT<F>(); // output is float anyway
std::vector<float> cachedValue(numChannels == 3 ? 0 : 4 * numChannels, 0);
auto func = PRAGMA_THREADS_FOR {
for (auto b = start; b < stop; ++b) {
auto pInput = inputPtr + b * inBatchWidth;
for (auto y = 0; y < resizerState.outHeight; ++y) {
auto pOutput = &pOutputY[(b * resizerState.outHeight + y) * resizerState.outWidth * numChannels];
for (auto y = 0; y < outHeight; ++y) {
auto pOutput = &pOutputY[(b * outHeight + y) * outWidth * numChannels];
WeightsAndIndices yWai;
if (halfPixelCenters) {
@ -713,16 +715,16 @@ namespace helpers {
resizerState.heightScale, y, resizerState.inHeight, &yWai);
}
// Make pointers represent offsets of data in inputBPtr.
const T *y_ptr_0 = pInput + yWai._index0 * inRowWidth;
const T *y_ptr_1 = pInput + yWai._index1 * inRowWidth;
const T *y_ptr_2 = pInput + yWai._index2 * inRowWidth;
const T *y_ptr_3 = pInput + yWai._index3 * inRowWidth;
const T* y_ptr_0 = pInput + yWai._index0 * inRowWidth;
const T* y_ptr_1 = pInput + yWai._index1 * inRowWidth;
const T* y_ptr_2 = pInput + yWai._index2 * inRowWidth;
const T* y_ptr_3 = pInput + yWai._index3 * inRowWidth;
if (numChannels == 3) {
// Manually unroll case of 3 channels.
float cached_value_0[4] = {0};
float cached_value_1[4] = {0};
float cached_value_2[4] = {0};
F cached_value_0[4] = {0};
F cached_value_1[4] = {0};
F cached_value_2[4] = {0};
for (auto x = 0; x < resizerState.outWidth; ++x) {
const WeightsAndIndices &xWai = xWais[x];
// Shift values in cached_value_* to fill first '_advance' values.
@ -854,7 +856,7 @@ namespace helpers {
}
for (auto c = 0; c < numChannels; ++c) {
pOutput[x * numChannels + c] =
compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1,
(F)compute(&cachedValue[4 * c], xWai._weight0, xWai._weight1,
xWai._weight2, xWai._weight3);
}
}
@ -862,7 +864,7 @@ namespace helpers {
}
}
};
samediff::Threads::parallel_tad(func, 0, resizerState.batchSize);
samediff::Threads::parallel_tad(func, 0, batchNum);
}
// simplified bicubic resize without antialiasing
@ -873,7 +875,7 @@ namespace helpers {
ImageResizerState st(alignCorners, halfPixelAlign); // align_corners, half_pixel_align
int res = st.validateAndCreateOutput(image, width, height);
if (res == Status::OK())
bicubicInterpolateWithCaching<T>(image, st, halfPixelAlign, output);
bicubicInterpolateWithCaching<T, float>(image, st, halfPixelAlign, output);
return res;
}

View File

@ -975,6 +975,118 @@ TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) {
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) {
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511,
0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613,
0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989,
0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312,
0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699});
NDArray expected = NDArrayFactory::create<float>('c', {2, 9, 9, 1}, {
0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f,
0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f,
0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f,
0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f,
0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f,
0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f,
0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f,
0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f,
0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f,
0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f,
0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f,
0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f,
0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f,
0.61700624f, 0.7214321f, 0.7307374f, 0.6251454f, 0.3924176f,
0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f,
0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f,
0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f,
0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f,
0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f,
0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f,
0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f,
0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f,
0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f,
0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f,
0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f,
0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f,
0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f,
0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f,
0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f,
0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f,
0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f,
0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f,
0.9139262f, 0.92068815f
});
auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 9x9");
// expected.printBuffer("Expect for 9x9");
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) {
NDArray input = NDArrayFactory::create<double>('c', {2, 5, 5, 1}, {
0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338,
0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859,
0.41160882670429033, 0.3997152563642703, 0.4074505147711718, 0.6274595060113246, 0.8017922711300232,
0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986,
0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547,
0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666,
0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496,
0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752,
0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968,
0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859});
auto testData = NDArrayFactory::create<float>('c', {2,9,9,1}, {
0.230286f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f,
0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f,
0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f,
0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f,
0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f,
0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f,
0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f,
0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f,
0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f,
0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f,
0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f,
0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f,
0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f,
0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f,
0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f,
0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f,
0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f,
0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f
});
auto size = NDArrayFactory::create<int>({9, 9});
nd4j::ops::resize_bicubic op;
auto results = op.execute({&input, &size}, {}, {}, {true, false});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
NDArray* result = results->at(0);
// result->printBuffer("Resized to 9x9");
// expected.printBuffer("Expect for 9x9");
ASSERT_TRUE(testData.isSameShape(result));
ASSERT_TRUE(testData.equalsTo(result));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests11, summaryStatsData_test1) {

View File

@ -232,10 +232,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.*;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.custom.RandomNormal;
import org.nd4j.linalg.api.ops.random.custom.*;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
@ -384,6 +381,18 @@ public class DifferentialFunctionFactory {
return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable();
}
public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) {
return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable();
}
public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) {
return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable();
}
public SDVariable randomShuffle(SDVariable values, int... seeds) {
return new RandomShuffle(sameDiff(), values, seeds).outputVariable();
}
/**
* Exponential distribution: P(x) = lambda * exp(-lambda * x)
*

View File

@ -3,10 +3,7 @@ package org.nd4j.autodiff.samediff.ops;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.custom.AdjustContrast;
import org.nd4j.linalg.api.ops.custom.AdjustHue;
import org.nd4j.linalg.api.ops.custom.AdjustSaturation;
import org.nd4j.linalg.api.ops.custom.RandomCrop;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
@ -119,4 +116,26 @@ public class SDImage extends SDOps {
SDVariable out = new RandomCrop(sd, input, shape).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting array from HSV to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable rgbToHsv(String name, @NonNull SDVariable input) {
SDVariable out = new RgbToHsv(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting image from HSV to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable hsvToRgb(String name, @NonNull SDVariable input) {
SDVariable out = new HsvToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
}

View File

@ -19,6 +19,7 @@ package org.nd4j.autodiff.samediff.ops;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import static org.nd4j.autodiff.samediff.ops.SDValidation.validateInteger;
@ -295,4 +296,43 @@ public class SDRandom extends SDOps {
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable with Gamma distribution
*
* @param name Name of the output variable
* @param alpha distribution parameter
* @param beta distribution parameter
* @param shape Shape of the new variable
* @return new SDVariable
*/
public SDVariable gamma(String name, SDVariable shape, SDVariable alpha, SDVariable beta) {
SDVariable ret = f().randomGamma(alpha, beta, shape);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable with Poission distribution
*
* @param name Name of the output variable
* @param lambda rate distribution parameter
* @param shape Shape of the new variable
* @return new SDVariable
*/
public SDVariable poisson(String name, SDVariable lambda, SDVariable shape, int... seeds) {
SDVariable ret = f().randomPoisson(shape, lambda, seeds);
return updateVariableNameAndReference(ret, name);
}
/**
* Generate a new random SDVariable by random shuffle
*
* @param name Name of the output variable
* @param value array to shuffle
* @return new SDVariable
*/
public SDVariable shuffle(String name, SDVariable value, int... seeds) {
SDVariable ret = f().randomShuffle(value, seeds);
return updateVariableNameAndReference(ret, name);
}
}

View File

@ -572,6 +572,9 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.custom.RandomBernoulli.class,
org.nd4j.linalg.api.ops.random.custom.RandomExponential.class,
org.nd4j.linalg.api.ops.random.custom.RandomNormal.class,
org.nd4j.linalg.api.ops.random.custom.RandomGamma.class,
org.nd4j.linalg.api.ops.random.custom.RandomPoisson.class,
org.nd4j.linalg.api.ops.random.custom.RandomShuffle.class,
org.nd4j.linalg.api.ops.random.impl.AlphaDropOut.class,
org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution.class,
org.nd4j.linalg.api.ops.random.impl.BinomialDistribution.class,
@ -588,6 +591,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
org.nd4j.linalg.api.ops.custom.AdjustContrast.class,
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
org.nd4j.linalg.api.ops.custom.BitCast.class,
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
@ -601,7 +606,10 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.Polygamma.class,
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
org.nd4j.linalg.api.ops.custom.Roll.class,
org.nd4j.linalg.api.ops.custom.ToggleBits.class
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
org.nd4j.linalg.api.ops.custom.Igamma.class,
org.nd4j.linalg.api.ops.custom.Igammac.class,
org.nd4j.linalg.api.ops.custom.Digamma.class
);
static {

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class Digamma extends DynamicCustomOp {
public Digamma(@NonNull INDArray x) {
addInputArgument(x);
}
public Digamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) {
super("", sameDiff, new SDVariable[]{x});
}
@Override
public String opName() {
return "digamma";
}
@Override
public String tensorflowName() {
return "Digamma";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -19,15 +19,23 @@ import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class FusedBatchNorm extends DynamicCustomOp {
private DataType outputDataType;
public FusedBatchNorm() {}
public FusedBatchNorm(@NonNull INDArray x, @NonNull INDArray scale, @NonNull INDArray offset,
@ -38,6 +46,7 @@ public class FusedBatchNorm extends DynamicCustomOp {
if (yOut != null && batchMeanOut != null && batchMeanVar != null) {
addOutputArgument(yOut, batchMeanOut, batchMeanVar);
}
this.outputDataType = x.dataType();
}
public FusedBatchNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable scale, @NonNull SDVariable offset,
@ -51,14 +60,25 @@ public class FusedBatchNorm extends DynamicCustomOp {
}
@Override
public String tensorflowName() {
return "FusedBatchNormV2";
public String[] tensorflowNames() {
return new String[]{"FusedBatchNormV2","FusedBatchNormV3"};
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
boolean isNchw = attributesForNode.containsKey("data_format") && attributesForNode.get("data_format").getS().toStringUtf8().equalsIgnoreCase("NCHW");
boolean training = !attributesForNode.containsKey("is_training") ? true : attributesForNode.get("is_training").getB();
addIArgument(isNchw ? 1 : 0);
addIArgument(training ? 1 : 0);
if(attributesForNode.containsKey("T")){
outputDataType = TFGraphMapper.convertType(attributesForNode.get("T").getType());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
return Arrays.asList(outputDataType, DataType.FLOAT, DataType.FLOAT); //Activations may be half, bfloat16, float32; mean/var is always float
}
}

View File

@ -0,0 +1,57 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class HsvToRgb extends DynamicCustomOp {
public HsvToRgb(INDArray input) {
addInputArgument(input);
}
public HsvToRgb(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "hsv_to_rgb";
}
@Override
public String tensorflowName() {
return "HSVToRGB";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,65 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class Igamma extends DynamicCustomOp {
public Igamma(@NonNull INDArray n, @NonNull INDArray x) {
Preconditions.checkArgument(n.shape() != x.shape(),
"Igamma: n and x must have the same shapes");
addInputArgument(n,x);
}
public Igamma(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
this(n,x);
if (output != null) {
addOutputArgument(output);
}
}
public Igamma(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
super("", sameDiff, new SDVariable[]{n ,x});
}
@Override
public String opName() {
return "igamma";
}
@Override
public String tensorflowName() {
return "Igamma";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,66 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class Igammac extends DynamicCustomOp {
public Igammac(@NonNull INDArray n, @NonNull INDArray x) {
Preconditions.checkArgument(n.shape() != x.shape(),
"Igamma: n and x must have the same shapes");
addInputArgument(n,x);
}
public Igammac(@NonNull INDArray n, @NonNull INDArray x, INDArray output) {
this(n,x);
if (output != null) {
addOutputArgument(output);
}
}
public Igammac(@NonNull SameDiff sameDiff, @NonNull SDVariable n, @NonNull SDVariable x) {
super("", sameDiff, new SDVariable[]{n ,x});
}
@Override
public String opName() {
return "igammac";
}
@Override
public String tensorflowName() {
return "Igammac";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,57 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RgbToHsv extends DynamicCustomOp {
public RgbToHsv(INDArray input) {
addInputArgument(input);
}
public RgbToHsv(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "rgb_to_hsv";
}
@Override
public String tensorflowName() {
return "RGBToHSV";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -61,7 +61,7 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}

View File

@ -61,7 +61,7 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}

View File

@ -61,7 +61,7 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -61,7 +61,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -60,7 +60,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
List<DataType> out = new ArrayList<>();
for( int i=0; i<numSegments; i++ ){
out.add(inputDataTypes.get(0));

View File

@ -62,7 +62,7 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes);
//TODO Allow customizing output type
return Collections.singletonList(Nd4j.defaultFloatingPointType());
}

View File

@ -0,0 +1,84 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.random.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class RandomGamma extends DynamicCustomOp {
public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta,
int... seeds) {
if (beta != null) {
addInputArgument(shape,alpha,beta);
}
addInputArgument(shape,alpha);
addIArgument(seeds);
}
public RandomGamma(@NonNull INDArray shape, @NonNull INDArray alpha, INDArray beta) {
this(shape,alpha,beta,0,0);
}
public RandomGamma(@NonNull SameDiff sameDiff, @NonNull SDVariable shape,
@NonNull SDVariable alpha, SDVariable beta, int... seeds) {
super(null, sameDiff, beta != null ? new SDVariable[]{shape, alpha, beta} :
new SDVariable[]{shape, alpha});
addIArgument(seeds);
}
@Override
public String opName() {
return "random_gamma";
}
@Override
public String tensorflowName() {
return "RandomGamma";
}
private DataType outputDataType = DataType.FLOAT;
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("alpha")) {
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("alpha").getType());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null, "Expected exactly input datatypes for %s, got null", getClass());
return Collections.singletonList(outputDataType);
}
}

View File

@ -0,0 +1,79 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.random.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.rng.Random;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class RandomPoisson extends DynamicCustomOp {
private DataType outputDataType = DataType.FLOAT;
public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate, int... seeds) {
addInputArgument(shape, rate);
addIArgument(seeds);
}
public RandomPoisson(@NonNull INDArray shape, @NonNull INDArray rate) {
this(shape, rate, 0,0);
}
public RandomPoisson(@NonNull SameDiff sameDiff, @NonNull SDVariable shape, @NonNull SDVariable rate, int... seeds) {
super(null, sameDiff, new SDVariable[]{shape, rate});
addIArgument(seeds);
}
@Override
public String opName() {
return "random_poisson";
}
@Override
public String[] tensorflowNames() {
return new String[]{"RandomPoisson","RandomPoissonV2"};
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("dtype")) {
outputDataType = DataTypeAdapter.dtypeConv(attributesForNode.get("dtype").getType());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s",
getClass(), inputDataTypes.size());
return Collections.singletonList(outputDataType);
}
}

View File

@ -0,0 +1,63 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.random.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RandomShuffle extends DynamicCustomOp {
public RandomShuffle(@NonNull INDArray value, int... seeds) {
addInputArgument(value);
addIArgument(seeds);
}
public RandomShuffle(@NonNull INDArray value) {
this(value, 0, 0);
}
public RandomShuffle(@NonNull SameDiff sameDiff, @NonNull SDVariable value, int...seeds) {
super(null, sameDiff, new SDVariable[]{value});
addIArgument(seeds);
}
@Override
public String opName() {
return "random_shuffle";
}
@Override
public String tensorflowName() {
return "RandomShuffle";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -63,8 +63,8 @@ import static org.junit.Assume.assumeFalse;
TransformOpValidation.class,
//TF import tests
TFGraphTestAllSameDiff.class,
TFGraphTestAllLibnd4j.class
TFGraphTestAllSameDiff.class
//TFGraphTestAllLibnd4j.class
})
//IMPORTANT: This ignore is added to avoid maven surefire running both the suite AND the individual tests in "mvn test"
// With it ignored here, the individual tests will run outside (i.e., separately/independently) of the suite in both "mvn test" and IntelliJ

View File

@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.OpValidationSuite;
import org.nd4j.autodiff.functions.DifferentialFunction;
@ -1465,6 +1466,7 @@ public class TransformOpValidation extends BaseOpValidation {
}
@Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
@Test
public void testPad(){
INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);

View File

@ -26,6 +26,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.nd4j.autodiff.execution.NativeGraphExecutioner;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
@ -228,9 +229,9 @@ public class TFGraphTestAllHelper {
String s1 = s.format(tfPred, false);
String s2 = s.format(nd4jPred, false);
System.out.print("TF: ");
System.out.println(s1);
System.out.println(tfPred.toStringFull());
System.out.print("SD: ");
System.out.println(s2);
System.out.println(nd4jPred.toStringFull());
}
}
assertTrue("Predictions do not match on " + modelName + ", node " + outputNode, eq);

View File

@ -111,17 +111,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
// 2019/11/15 - missing dtype argument in nd4j, tests are useless https://github.com/eclipse/deeplearning4j/issues/8398
"zeros_like/rank2_float32_dtype_int.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8450
"betainc.*",
// 11.26.2019 failing - https://github.com/eclipse/deeplearning4j/issues/8453
"roll/.*",
// 11.26.2019 failing https://github.com/eclipse/deeplearning4j/issues/8455
"matrix_band_part/.*",
// 05.12.2019 failing https://github.com/eclipse/deeplearning4j/issues/8507
"resize_bicubic/int32.*"
"matrix_band_part/.*"
};
/* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have

View File

@ -23,6 +23,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.config.ND4JSystemProperties;
@ -51,6 +52,9 @@ public abstract class BaseNd4jTest {
@Rule
public TestName testName = new TestName();
@Rule
public Timeout timeout = Timeout.seconds(30);
protected long startTime;
protected int threadCountBefore;

View File

@ -34,8 +34,6 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.layers.convolution.MaxPoolWithArgmax;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
@ -1353,4 +1351,80 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(exp, result);
}
// Exact copy of libnd4j test
@Test
public void testRgbToHsv() {
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3);
INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3);
RgbToHsv op = new RgbToHsv(input);
INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected);
}
// Exact copy of libnd4j test
@Test
public void testHsvToRgb() {
INDArray input = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f,
0.8745098f, 180.f, 0.74871795f, 0.76470588f,
77.6f, 0.49019608f, 0.6f, 260.74468085f,
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
INDArray expected = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
HsvToRgb op = new HsvToRgb(input);
INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected);
}
@Ignore
@Test
public void testHsvToRgb_1() {
/* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,1,3])
tf.image.hsv_to_rgb(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}).
reshape(1,1,3);
HsvToRgb op = new HsvToRgb(image);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3);
assertEquals(expected, ret[0]);
}
@Ignore
@Test
public void testRgbToHsv_1() {
/* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,2,3])
tf.image.rgb_to_hsv(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,
0.2309f,0.7271f,0.1804f}).reshape(1,2,3);
RgbToHsv op = new RgbToHsv(image);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f,
0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3);
assertEquals(expected, ret[0]);
}
}

View File

@ -31,8 +31,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.*;
import org.nd4j.linalg.api.ops.random.impl.*;
import org.nd4j.linalg.api.rng.DefaultRandom;
import org.nd4j.linalg.api.rng.Random;
@ -1473,6 +1472,44 @@ public class RandomTests extends BaseNd4jTest {
assertEquals(out1, out2);
}
@Test
public void testGamma(){
Nd4j.getRandom().setSeed(12345);
INDArray shape = Nd4j.createFromArray(new int[] {1,3});
INDArray alpha = Nd4j.rand(1,3);
val randomGamma = new RandomGamma(shape, alpha, null);
INDArray[] res = Nd4j.exec(randomGamma);
val randomGamma1 = new RandomGamma(shape, alpha, null);
INDArray[] res1 = Nd4j.exec(randomGamma1);
assertEquals(res[0], res1[0]);
}
@Test
public void testPoisson(){
Nd4j.getRandom().setSeed(12345);
INDArray shape = Nd4j.createFromArray(new int[] {1,3});
INDArray alpha = Nd4j.rand(1,3);
val randomPoisson = new RandomPoisson(shape, alpha);
INDArray[] res = Nd4j.exec(randomPoisson);
val randomPoisson1 = new RandomPoisson(shape, alpha);
INDArray[] res1 = Nd4j.exec(randomPoisson1);
assertEquals(res[0], res1[0]);
}
@Test
public void testShuffle(){
Nd4j.getRandom().setSeed(12345);
INDArray alpha = Nd4j.rand(1,3);
val randomShuffle = new RandomShuffle(alpha);
INDArray[] res = Nd4j.exec(randomShuffle);
val randomShuffle1 = new RandomShuffle(alpha);
INDArray[] res1 = Nd4j.exec(randomShuffle1);
assertEquals(res[0], res1[0]);
}
@Override
public char ordering() {
return 'c';