From d5002b14c7774792399e7df83c9f83babefc3f08 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 12:59:08 +0300 Subject: [PATCH 1/7] New ops wrappers --- .../linalg/api/ops/custom/AdjustContrast.java | 19 +++ .../api/ops/custom/AdjustContrastV2.java | 19 +++ .../api/ops/custom/BaseAdjustContrast.java | 19 +++ .../nd4j/linalg/api/ops/custom/BitCast.java | 21 ++++ .../api/ops/custom/CompareAndBitpack.java | 20 +++ .../linalg/api/ops/custom/DivideNoNan.java | 21 ++++ .../api/ops/custom/DrawBoundingBoxes.java | 21 ++++ .../FakeQuantWithMinMaxVarsPerChannel.java | 25 ++++ .../nd4j/linalg/custom/CustomOpsTests.java | 117 +++++++++++++++++- 9 files changed, 280 insertions(+), 2 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java new file mode 100644 index 000000000..aad384a26 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -0,0 +1,19 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class AdjustContrast extends BaseAdjustContrast { + + public AdjustContrast() {super();} + + public AdjustContrast(INDArray in, double factor, INDArray out) { + super(in, factor, out); + } + + @Override + public String opName() { + return "adjust_contrast"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java new file mode 100644 index 000000000..4be4ae098 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -0,0 +1,19 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class AdjustContrastV2 extends BaseAdjustContrast { + + public AdjustContrastV2() {super();} + + public AdjustContrastV2(INDArray in, double factor, INDArray out) { + super(in, factor, out); + } + + @Override + public String opName() { + return "adjust_contrast_v2"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java new file mode 100644 index 000000000..7057118c5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -0,0 +1,19 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public abstract class BaseAdjustContrast extends DynamicCustomOp { + public BaseAdjustContrast() { + } + + public BaseAdjustContrast(INDArray in, double factor, INDArray out) { + Preconditions.checkArgument(in.rank() >= 3, + String.format("AdjustContrast: op expects rank of input array to be >= 3, but got %d instead", in.rank())); + inputArguments.add(in); + outputArguments.add(out); + + addTArgument(factor); + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java new file mode 100644 index 000000000..7a1f125c6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -0,0 +1,21 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +public class BitCast extends DynamicCustomOp { + public BitCast() {} + + public BitCast(INDArray in, int dataType, INDArray out) { + inputArguments.add(in); + outputArguments.add(out); + iArguments.add(Long.valueOf(dataType)); + } + + @Override + public String opName() { + return "bitcast"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java new file mode 100644 index 000000000..4f0aad2ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -0,0 +1,20 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.factory.Nd4j; + +public class CompareAndBitpack extends DynamicCustomOp { + public CompareAndBitpack() {} + + public CompareAndBitpack(INDArray in, double threshold, INDArray out) { + inputArguments.add(in); + inputArguments.add(Nd4j.scalar(threshold)); + outputArguments.add(out); + } + + @Override + public String opName() { + return "compare_and_bitpack"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java new file mode 100644 index 000000000..b2eafb791 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -0,0 +1,21 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.apache.commons.math3.analysis.function.Divide; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class DivideNoNan extends DynamicCustomOp { + public DivideNoNan() { + } + + public DivideNoNan(INDArray in1, INDArray in2, INDArray out) { + inputArguments.add(in1); + inputArguments.add(in2); + outputArguments.add(out); + } + + @Override + public String opName() { + return "divide_no_nan"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java new file mode 100644 index 000000000..c6cf04b62 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -0,0 +1,21 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class DrawBoundingBoxes extends DynamicCustomOp { + public DrawBoundingBoxes() {} + + public DrawBoundingBoxes(INDArray images, INDArray boxes, INDArray colors, + INDArray output) { + inputArguments.add(images); + inputArguments.add(boxes); + inputArguments.add(colors); + outputArguments.add(output); + } + + @Override + public String opName() { + return "draw_bounding_boxes"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java new file mode 100644 index 000000000..3bdcf6dd3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -0,0 +1,25 @@ +package org.nd4j.linalg.api.ops.custom; + +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { + public FakeQuantWithMinMaxVarsPerChannel() {} + + public FakeQuantWithMinMaxVarsPerChannel(INDArray x, INDArray min, INDArray max, + INDArray output) { + Preconditions.checkArgument(min.isVector() && max.isVector() && + min.length() == max.length(), + "FakeQuantWithMinMaxVarsPerChannel: min and max should be 1D tensors with the same length"); + inputArguments.add(x); + inputArguments.add(min); + inputArguments.add(max); + outputArguments.add(output); + } + + @Override + public String opName() { + return "fake_quant_with_min_max_vars_per_channel"; + } +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index ded23f810..d93c934f0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -26,8 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.custom.Flatten; -import org.nd4j.linalg.api.ops.custom.ScatterUpdate; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; @@ -807,4 +806,118 @@ public class CustomOpsTests extends BaseNd4jTest { Nd4j.getExecutioner().commit(); } + + @Test + public void testAdjustContrast() { + INDArray in = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 4*4*3).reshape(4,4,3); + INDArray out = Nd4j.zeros(4,4,3); + + INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 + }).reshape(4,4,3); + Nd4j.exec(new AdjustContrast(in, 2.0, out)); + + assertArrayEquals(out.shape(), in.shape()); + assertEquals(expected, out); + } + + @Test + public void testAdjustContrastV2() { + INDArray in = Nd4j.linspace(DataType.DOUBLE,1.0,1.0, 4*4*3).reshape(4,4,3); + INDArray out = Nd4j.createUninitialized(4,4,3); + + INDArray expected = Nd4j.createFromArray(new double[]{-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 + }).reshape(4,4,3); + + Nd4j.exec(new AdjustContrastV2(in, 2.0, out)); + + assertArrayEquals(out.shape(), in.shape()); + assertEquals(expected, out); + } + + @Test + public void testBitCast() { + INDArray in = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 8).reshape(2,2,2); + INDArray out = Nd4j.createUninitialized(2,2); + + Nd4j.exec(new BitCast(in, DataType.DOUBLE.toInt(), out)); + + INDArray expected = Nd4j.createFromArray(new double[]{2., 512., 8192., 131072.032 }).reshape(2,2); + assertArrayEquals(new long[]{2,2}, out.shape()); + assertEquals(expected, out); + } + + @Test + public void testCompareAndBitpack() { + INDArray in = Nd4j.createFromArray(new double[]{-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, + -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}).reshape( 2,3,4); + INDArray out = Nd4j.createUninitialized(DataType.UBYTE, 2,3,4); + INDArray expected = Nd4j.createFromArray(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1}). + reshape(2,3,4); + + Nd4j.exec(new CompareAndBitpack(in ,2.0, out)); + assertArrayEquals(new long[]{2,3,4}, out.shape()); + } + + @Test + public void testDivideNoNan() { + INDArray in1 = Nd4j.rand(DataType.DOUBLE, 2,3,4); + INDArray in2 = Nd4j.rand(DataType.DOUBLE, 2,3,4); + INDArray out = Nd4j.createUninitialized(DataType.DOUBLE, 2,3,4); + + Nd4j.exec(new DivideNoNan(in1, in2, out)); + assertArrayEquals(new long[]{2,3,4}, out.shape()); + } + + @Test + public void testDrawBoundingBoxes() { + INDArray images = Nd4j.linspace(DataType.FLOAT, 1.0f, 1.0f, 2*4*5*3).reshape(2,4,5,3); + INDArray boxes = Nd4j.createFromArray(new float[]{ 0.0f , 0.0f , 1.0f , 1.0f, + 0.1f, 0.2f, 0.9f, 0.8f, + 0.3f, 0.3f, 0.7f, 0.7f, + 0.4f, 0.4f, 0.6f, 0.6f}).reshape(2,2,4); + INDArray colors = Nd4j.createFromArray(new float[]{ + 201.0f, 202.0f, 203.0f, 127.0f, 128.0f, 129.0f}). + reshape(2,3); + INDArray output = Nd4j.create(DataType.FLOAT, images.shape()); + INDArray expected = Nd4j.createFromArray(new float[]{127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, + 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 201.f, 202.f, 203.f, 201.f ,202.f ,203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, + + 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, + 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f}). + reshape(2,4,5,3); + + Nd4j.exec(new DrawBoundingBoxes(images, boxes, colors, output)); + + assertArrayEquals(images.shape(), output.shape()); + assertEquals(expected, output); + } + + @Test + public void FakeQuantWithMinMaxVarsPerChannel() { + + INDArray x = Nd4j.createFromArray(new float[]{-63.80f, -63.75f, -63.4f, -63.5f, 0.0f, 0.1f}). + reshape(1,2,3,1); + + INDArray min = Nd4j.createFromArray(new float[]{-63.65f}); + INDArray max = Nd4j.createFromArray(new float[]{0.1f}); + + INDArray output = Nd4j.createUninitialized(DataType.FLOAT, 1,2,3,1); + INDArray expected = Nd4j.createFromArray(new float[]{-63.75f, -63.75f, -63.5f, -63.5f, 0.f, 0.f}). + reshape(1,2,3,1); + + Nd4j.exec(new FakeQuantWithMinMaxVarsPerChannel(x,min,max,output)); + + assertEquals(expected, output); + } } From c4307384f35b9e21b31fc6334c86fa8e10ed70d2 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 12:59:25 +0300 Subject: [PATCH 2/7] Fixed shape for muli --- .../models/word2vec/Word2VecTests.java | 32 +++++++++++++++++++ .../reader/impl/BasicModelUtils.java | 3 +- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index 01b38a644..736998484 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -50,6 +50,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; import java.util.*; import static org.junit.Assert.*; @@ -816,6 +817,37 @@ public class Word2VecTests extends BaseDL4JTest { assertEquals(vec1.getWordVectorMatrix("money"), vec2.getWordVectorMatrix("money")); } + @Test + public void testWordsNearestSum() throws IOException { + log.info("Load & Vectorize Sentences...."); + SentenceIterator iter = new BasicLineIterator(inputFile); + TokenizerFactory t = new DefaultTokenizerFactory(); + t.setTokenPreProcessor(new CommonPreprocessor()); + + log.info("Building model...."); + Word2Vec vec = new Word2Vec.Builder() + .minWordFrequency(5) + .iterations(1) + .layerSize(100) + .seed(42) + .windowSize(5) + .iterate(iter) + .tokenizerFactory(t) + .build(); + + log.info("Fitting Word2Vec model...."); + vec.fit(); + log.info("Writing word vectors to text file...."); + log.info("Closest Words:"); + Collection lst = vec.wordsNearestSum("day", 10); + log.info("10 Words closest to 'day': {}", lst); + assertTrue(lst.contains("week")); + assertTrue(lst.contains("night")); + assertTrue(lst.contains("year")); + assertTrue(lst.contains("years")); + assertTrue(lst.contains("time")); + } + private static void printWords(String target, Collection list, Word2Vec vec) { System.out.println("Words close to [" + target + "]:"); for (String word : list) { diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index 84fc17b7e..4912a3c47 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -351,7 +351,8 @@ public class BasicModelUtils implements ModelUtils if (lookupTable instanceof InMemoryLookupTable) { InMemoryLookupTable l = (InMemoryLookupTable) lookupTable; INDArray syn0 = l.getSyn0(); - INDArray weights = syn0.norm2(0).rdivi(1).muli(words); + INDArray temp = syn0.norm2(0).rdivi(1).reshape(words.shape()); + INDArray weights = temp.muli(words); INDArray distances = syn0.mulRowVector(weights).sum(1); INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false); INDArray sort = sorted[0]; From 96a9a1a733f78884fe4209076180a176aae588ba Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 18:07:52 +0300 Subject: [PATCH 3/7] Fixed output from operation. --- .../ops/declarable/generic/parity_ops/adjust_contrast.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp index 4011d5e32..538214b14 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/adjust_contrast.cpp @@ -47,7 +47,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 1, 0) { NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext()); factorT.p(0, factor); // this is contrast calculation - *output = (*input - mean) * factorT + mean; + output->assign((*input - mean) * factorT + mean); return Status::OK(); } From 99d77e138412309eb8411b8d91ec7613950ecd7c Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 19:16:47 +0300 Subject: [PATCH 4/7] Ops exported for sameDiff --- .../DifferentialFunctionFactory.java | 28 +++++++++++++++++++ .../linalg/api/ops/custom/AdjustContrast.java | 6 ++++ .../api/ops/custom/AdjustContrastV2.java | 6 ++++ .../api/ops/custom/BaseAdjustContrast.java | 10 +++++++ .../nd4j/linalg/api/ops/custom/BitCast.java | 6 ++++ .../api/ops/custom/CompareAndBitpack.java | 6 ++++ .../linalg/api/ops/custom/DivideNoNan.java | 6 ++++ .../api/ops/custom/DrawBoundingBoxes.java | 6 ++++ .../FakeQuantWithMinMaxVarsPerChannel.java | 6 ++++ 9 files changed, 80 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 1a40fbd11..4b042dded 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.NoOp; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; @@ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory { return new NextIteration(sameDiff, x).outputVariable(); } + public SDVariable adjustContrast(SDVariable in, SDVariable factor) { + return new AdjustContrast(sameDiff, in, factor).outputVariable(); + } + + public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { + return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); + } + + public SDVariable bitCast(SDVariable in, SDVariable dataType) { + return new BitCast(sameDiff, in, dataType).outputVariable(); + } + + public SDVariable compareAndBitpack(SDVariable threshold) { + return new CompareAndBitpack(sameDiff, threshold).outputVariable(); + } + + public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { + return new DivideNoNan(sameDiff, in1, in2).outputVariable(); + } + + public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { + return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); + } + + public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) { + return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); + } public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index aad384a26..181b1657d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -12,6 +14,10 @@ public class AdjustContrast extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super(sameDiff,new SDVariable[]{in,factor}); + } + @Override public String opName() { return "adjust_contrast"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 4be4ae098..74359da7f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -12,6 +14,10 @@ public class AdjustContrastV2 extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super( sameDiff,new SDVariable[]{in,factor}); + } + @Override public String opName() { return "adjust_contrast_v2"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java index 7057118c5..fe14fe69c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -16,4 +18,12 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { addTArgument(factor); } + + public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { + super("", sameDiff, vars); + } + + public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) { + super(null, sameDiff, new SDVariable[]{in, factor, out}); + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index 7a1f125c6..fbfad0305 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class BitCast extends DynamicCustomOp { iArguments.add(Long.valueOf(dataType)); } + public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { + super("", sameDiff, new SDVariable[]{in, dataType}); + } + @Override public String opName() { return "bitcast"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java index 4f0aad2ee..eb0762f0f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -13,6 +15,10 @@ public class CompareAndBitpack extends DynamicCustomOp { outputArguments.add(out); } + public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) { + super("", sameDiff, new SDVariable[]{threshold}); + } + @Override public String opName() { return "compare_and_bitpack"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index b2eafb791..ce67b14f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -1,6 +1,8 @@ package org.nd4j.linalg.api.ops.custom; import org.apache.commons.math3.analysis.function.Divide; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class DivideNoNan extends DynamicCustomOp { outputArguments.add(out); } + public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) { + super("", sameDiff, new SDVariable[]{in1, in2}); + } + @Override public String opName() { return "divide_no_nan"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index c6cf04b62..2ac6e6458 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class DrawBoundingBoxes extends DynamicCustomOp { outputArguments.add(output); } + public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) { + super("", sameDiff, new SDVariable[]{boxes, colors}); + } + @Override public String opName() { return "draw_bounding_boxes"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index 3bdcf6dd3..2043732d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -18,6 +20,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { outputArguments.add(output); } + public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { + super("", sameDiff, new SDVariable[]{x, min, max}); + } + @Override public String opName() { return "fake_quant_with_min_max_vars_per_channel"; From ec722b20ee7ac52995b4589479a64521e0857cb2 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 19:29:19 +0300 Subject: [PATCH 5/7] TF names added --- .../java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java | 5 +++++ .../org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java | 5 +++++ .../main/java/org/nd4j/linalg/api/ops/custom/BitCast.java | 5 +++++ .../org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java | 5 +++++ .../java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java | 5 +++++ .../org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java | 5 +++++ .../api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java | 5 +++++ 7 files changed, 35 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index 181b1657d..df94016ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -22,4 +22,9 @@ public class AdjustContrast extends BaseAdjustContrast { public String opName() { return "adjust_contrast"; } + + @Override + public String tensorflowName() { + return "adjust_contrast"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 74359da7f..8c2e3b8d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -22,4 +22,9 @@ public class AdjustContrastV2 extends BaseAdjustContrast { public String opName() { return "adjust_contrast_v2"; } + + @Override + public String tensorflowName() { + return "adjust_contrast"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index fbfad0305..b565125cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -24,4 +24,9 @@ public class BitCast extends DynamicCustomOp { public String opName() { return "bitcast"; } + + @Override + public String tensorflowName() { + return "bitcast"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java index eb0762f0f..d69c73da4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -23,4 +23,9 @@ public class CompareAndBitpack extends DynamicCustomOp { public String opName() { return "compare_and_bitpack"; } + + @Override + public String tensorflowName() { + return "CompareAndBitpack"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index ce67b14f9..a465c27dd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -24,4 +24,9 @@ public class DivideNoNan extends DynamicCustomOp { public String opName() { return "divide_no_nan"; } + + @Override + public String tensorflowName() { + return "divide_no_nan"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index 2ac6e6458..d319e0090 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -24,4 +24,9 @@ public class DrawBoundingBoxes extends DynamicCustomOp { public String opName() { return "draw_bounding_boxes"; } + + @Override + public String tensorflowName() { + return "draw_bounding_boxes"; + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index 2043732d4..b4c33c31a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -28,4 +28,9 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { public String opName() { return "fake_quant_with_min_max_vars_per_channel"; } + + @Override + public String tensorflowName() { + return "fake_quant_with_min_max_vars_per_channel"; + } } \ No newline at end of file From 502bedf5d5de5116cf3878370969293ed8df957e Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 19:39:04 +0300 Subject: [PATCH 6/7] Register ops for TF import. --- .../nd4j/imports/converters/ImportClassMapping.java | 10 ++++++++-- .../nd4j/linalg/api/ops/custom/BaseAdjustContrast.java | 4 ---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index fcf3fe630..2fd2e6332 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -581,8 +581,14 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class, org.nd4j.linalg.api.ops.random.impl.Range.class, org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class, - org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class - + org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class, + org.nd4j.linalg.api.ops.custom.AdjustContrast.class, + org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, + org.nd4j.linalg.api.ops.custom.BitCast.class, + org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, + org.nd4j.linalg.api.ops.custom.DivideNoNan.class, + org.nd4j.linalg.api.ops.custom.DrawBoundingBoxes.class, + org.nd4j.linalg.api.ops.custom.FakeQuantWithMinMaxVarsPerChannel.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java index fe14fe69c..cadef80e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -22,8 +22,4 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { super("", sameDiff, vars); } - - public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) { - super(null, sameDiff, new SDVariable[]{in, factor, out}); - } } \ No newline at end of file From 9e5799847a207d9494fe5e52af1a6db7021ce7b2 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 19:50:18 +0300 Subject: [PATCH 7/7] TF names fixed. --- .../java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java | 2 +- .../java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java | 2 +- .../src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java | 2 +- .../main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java | 2 +- .../java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java | 2 +- .../api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index df94016ef..80f29e577 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -25,6 +25,6 @@ public class AdjustContrast extends BaseAdjustContrast { @Override public String tensorflowName() { - return "adjust_contrast"; + return "AdjustContrast"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 8c2e3b8d3..80f8c106d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -25,6 +25,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast { @Override public String tensorflowName() { - return "adjust_contrast"; + return "AdjustContrast"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index b565125cf..ee0adfb94 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -27,6 +27,6 @@ public class BitCast extends DynamicCustomOp { @Override public String tensorflowName() { - return "bitcast"; + return "Bitcast"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index a465c27dd..400830ec3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -27,6 +27,6 @@ public class DivideNoNan extends DynamicCustomOp { @Override public String tensorflowName() { - return "divide_no_nan"; + return "DivNoNan"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index d319e0090..4c672a66c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -27,6 +27,6 @@ public class DrawBoundingBoxes extends DynamicCustomOp { @Override public String tensorflowName() { - return "draw_bounding_boxes"; + return "DrawBoundingBoxes"; } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index b4c33c31a..303ac8458 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -31,6 +31,6 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { @Override public String tensorflowName() { - return "fake_quant_with_min_max_vars_per_channel"; + return "FakeQuantWithMinMaxVarsPerChannel"; } } \ No newline at end of file