From 010744ef9c441ef98cbd2805f10be63cca8cdbca Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 30 Dec 2019 14:06:12 +0200 Subject: [PATCH] Lu wrapper and tests fixes (#144) * Tests fixed * Lu added * Test fixed * Default timeout * Tests timeouts fixed. * TF import fix * Timeouts added * Timeout fixed. * Test corrected * rgb and yiq conversion ops added * Converter ops added * Header * Yuv converters * API added * Empty test for matmul * Explanation * skip gemm/gemv on empty inputs Signed-off-by: raver119 * Test added * Correct test * one more empty pass-through for mmul Signed-off-by: raver119 * Cleanup * Test added * Test fixed * Added missing mapping * Added missing mapping Co-authored-by: raver119 --- .../datasets/MnistFetcherTest.java | 8 +- .../RecordReaderDataSetiteratorTest.java | 2 +- .../RecordReaderMultiDataSetIteratorTest.java | 2 +- .../fetchers/SvhnDataFetcherTest.java | 5 + .../iterator/MultipleEpochsIteratorTest.java | 2 +- .../iterator/TestEmnistDataSetIterator.java | 5 + .../parameterserver/BaseDL4JTest.java | 4 + .../org/deeplearning4j/zoo/BaseDL4JTest.java | 4 + libnd4j/include/helpers/cpu/MmulHelper.cpp | 9 + .../include/helpers/cuda_off/MmulHelper.cu | 9 + libnd4j/include/helpers/impl/MmulHelper.cpp | 3 + libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 33 ++- libnd4j/windows.md | 4 + .../nd4j/autodiff/samediff/ops/SDImage.java | 44 ++++ .../converters/ImportClassMapping.java | 8 +- .../org/nd4j/linalg/api/ops/custom/Lu.java | 69 +++++ .../linalg/api/ops/custom/RgbToGrayscale.java | 44 ++++ .../nd4j/linalg/api/ops/custom/RgbToYiq.java | 56 ++++ .../nd4j/linalg/api/ops/custom/RgbToYuv.java | 56 ++++ .../nd4j/linalg/api/ops/custom/YiqToRgb.java | 55 ++++ .../nd4j/linalg/api/ops/custom/YuvToRgb.java | 56 ++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 36 +++ .../nd4j/linalg/custom/CustomOpsTests.java | 248 +++++++++++++++--- 23 files changed, 716 insertions(+), 46 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index 7a057afc6..362e099e4 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; +import org.junit.*; import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.dataset.DataSet; @@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest { @ClassRule public static TemporaryFolder testDir = new TemporaryFolder(); + @Rule + public Timeout timeout = Timeout.seconds(300); @BeforeClass public static void setup() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 0772072b5..7dfd46a8d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -72,7 +72,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { @Rule - protected Timeout timeout = Timeout.seconds(300); + public Timeout timeout = Timeout.seconds(300); @Override public DataType getDataType(){ diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index cb534cc23..897876112 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -71,7 +71,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { public TemporaryFolder temporaryFolder = new TemporaryFolder(); @Rule - protected Timeout timeout = Timeout.seconds(300); + public Timeout timeout = Timeout.seconds(300); @Test public void testsBasic() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index 818bb752f..1815dff73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -17,7 +17,9 @@ package org.deeplearning4j.datasets.fetchers; import org.deeplearning4j.BaseDL4JTest; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import java.io.File; @@ -28,6 +30,9 @@ import static org.junit.Assert.assertTrue; */ public class SvhnDataFetcherTest extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Test public void testSvhnDataFetcher() throws Exception { SvhnDataFetcher fetch = new SvhnDataFetcher(); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index 7bad73f06..f37642c24 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -40,7 +40,7 @@ import static org.junit.Assert.*; public class MultipleEpochsIteratorTest extends BaseDL4JTest { @Rule - protected Timeout timeout = Timeout.seconds(300); + public Timeout timeout = Timeout.seconds(300); @Test public void testNextAndReset() throws Exception { diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java index 0c1049845..b0d5e0d25 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/datasets/iterator/TestEmnistDataSetIterator.java @@ -19,7 +19,9 @@ package org.deeplearning4j.datasets.iterator; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -33,6 +35,9 @@ import static org.junit.Assert.*; @Slf4j public class TestEmnistDataSetIterator extends BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java index 8e087cc2f..6684c6384 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper-parameter-server/src/test/java/org/deeplearning4j/parallelism/parameterserver/BaseDL4JTest.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; @@ -36,6 +37,9 @@ import java.util.Properties; @Slf4j public class BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Rule public TestName name = new TestName(); diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java index 8c2b9bb07..5d5cbd8a8 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/BaseDL4JTest.java @@ -22,6 +22,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.rules.TestName; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -37,6 +38,9 @@ import java.util.Properties; @Slf4j public class BaseDL4JTest { + @Rule + public Timeout timeout = Timeout.seconds(600); + @Rule public TestName name = new TestName(); diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index 189143f03..f0e8846e3 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -202,6 +202,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con if(C == nullptr) C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const auto aType = A->dataType(); const auto bType = B->dataType(); const auto cType = C->dataType(); @@ -307,6 +310,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + if (Y->isEmpty()) + return Y; + const int incx = X->stridesOf()[xLenDim]; const int incy = Y->stridesOf()[yLenDim]; @@ -511,6 +517,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con C = new NDArray(outOrder, cExpectedShape, B->dataType()); } + if (C->isEmpty()) + return C; + const int cRank = C->rankOf(); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); diff --git a/libnd4j/include/helpers/cuda_off/MmulHelper.cu b/libnd4j/include/helpers/cuda_off/MmulHelper.cu index 40bb9453f..bf366dc29 100644 --- a/libnd4j/include/helpers/cuda_off/MmulHelper.cu +++ b/libnd4j/include/helpers/cuda_off/MmulHelper.cu @@ -235,6 +235,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou if(C == nullptr) C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); const auto aType = A->dataType(); @@ -376,6 +379,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray* if(Y == nullptr) Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); + if (Y->isEmpty()) + return Y; + const int incx = X->strideAt(xLenDim); const int incy = Y->strideAt(yLenDim); @@ -634,6 +640,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con else C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); + if (C->isEmpty()) + return C; + const int cRank = C->rankOf(); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp index ab97ad137..716062a53 100644 --- a/libnd4j/include/helpers/impl/MmulHelper.cpp +++ b/libnd4j/include/helpers/impl/MmulHelper.cpp @@ -236,6 +236,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B, throw std::invalid_argument(""); } + if (z->isEmpty()) + return; + NDArray* xT(const_cast(x)), *yT(const_cast(y)), *zT(z); if((transX && xRank > 1) || (transY && yRank > 1)) { diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index f17c5aa5a..12069c67e 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -323,4 +323,35 @@ TEST_F(EmptyTests, test_empty_reshape_1) { delete result0; delete result1; -} \ No newline at end of file +} + + +TEST_F(EmptyTests, test_empty_matmul_1) { + auto x = NDArrayFactory::create('c', {0, 1}); + auto y = NDArrayFactory::create('c', {1, 0}); + auto e = NDArrayFactory::create('c', {0, 0}); + + nd4j::ops::matmul op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(EmptyTests, test_empty_matmul_2) { + auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto y = NDArrayFactory::create('c', {1, 4, 0}); + auto e = NDArrayFactory::create('c', {1, 0, 0}); + + nd4j::ops::matmul op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} diff --git a/libnd4j/windows.md b/libnd4j/windows.md index 884b2c3ee..57b40cb83 100644 --- a/libnd4j/windows.md +++ b/libnd4j/windows.md @@ -274,3 +274,7 @@ To build libnd4j with MKL: Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j. Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'. + + +float16_nhcw +float16_nhwc \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 4cc020e3a..7b662b960 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -138,4 +138,48 @@ public class SDImage extends SDOps { SDVariable out = new HsvToRgb(sd, input).outputVariable(); return updateVariableNameAndReference(out, name); } + + /** + * Converting array from RGB to YIQ format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToYiq(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToYiq(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YIQ to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable yiqToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new YiqToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting array from RGB to YUV format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable rgbToYuv(String name, @NonNull SDVariable input) { + SDVariable out = new RgbToYuv(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } + + /** + * Converting image from YUV to RGB format + * @param name name + * @param input 3D image + * @return 3D image + */ + public SDVariable yuvToRgb(String name, @NonNull SDVariable input) { + SDVariable out = new YuvToRgb(sd, input).outputVariable(); + return updateVariableNameAndReference(out, name); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 700d48a3c..0fbc3960d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -593,6 +593,11 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, org.nd4j.linalg.api.ops.custom.HsvToRgb.class, org.nd4j.linalg.api.ops.custom.RgbToHsv.class, + org.nd4j.linalg.api.ops.custom.RgbToYiq.class, + org.nd4j.linalg.api.ops.custom.RgbToGrayscale.class, + org.nd4j.linalg.api.ops.custom.YiqToRgb.class, + org.nd4j.linalg.api.ops.custom.RgbToYuv.class, + org.nd4j.linalg.api.ops.custom.YuvToRgb.class, org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class, @@ -609,7 +614,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.custom.ToggleBits.class, org.nd4j.linalg.api.ops.custom.Igamma.class, org.nd4j.linalg.api.ops.custom.Igammac.class, - org.nd4j.linalg.api.ops.custom.Digamma.class + org.nd4j.linalg.api.ops.custom.Digamma.class, + org.nd4j.linalg.api.ops.custom.Lu.class ); static { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java new file mode 100644 index 000000000..af1bf0155 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lu.java @@ -0,0 +1,69 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +@NoArgsConstructor +public class Lu extends DynamicCustomOp { + private DataType indexDataType; + + public Lu(INDArray input) { + addInputArgument(input); + } + + public Lu(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "lu"; + } + + @Override + public String tensorflowName() { + return "Lu"; + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if (attributesForNode.containsKey("output_idx_type")){ + indexDataType = TFGraphMapper.convertType(attributesForNode.get("output_idx_type").getType()); + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Arrays.asList(inputDataTypes.get(0), indexDataType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java new file mode 100644 index 000000000..6b71ba17f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToGrayscale.java @@ -0,0 +1,44 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +@NoArgsConstructor +public class RgbToGrayscale extends DynamicCustomOp { + + public RgbToGrayscale(INDArray image) { + addInputArgument(image); + } + + public RgbToGrayscale(SameDiff sameDiff, SDVariable image) { + super(sameDiff, new SDVariable[]{image}); + } + + @Override + public String opName() { + return "rgb_to_grs"; + } + + @Override + public String tensorflowName() { + return "RgbToGrs"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java new file mode 100644 index 000000000..1d6a48a4f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYiq.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToYiq extends DynamicCustomOp { + + public RgbToYiq(INDArray input) { + addInputArgument(input); + } + + public RgbToYiq(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_yiq"; + } + + @Override + public String tensorflowName() { + return "RgbToYiq"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java new file mode 100644 index 000000000..c65c6e777 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/RgbToYuv.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class RgbToYuv extends DynamicCustomOp { + public RgbToYuv(INDArray input) { + addInputArgument(input); + } + + public RgbToYuv(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "rgb_to_yuv"; + } + + @Override + public String tensorflowName() { + return "RgbToYuv"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java new file mode 100644 index 000000000..8126a1803 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YiqToRgb.java @@ -0,0 +1,55 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class YiqToRgb extends DynamicCustomOp { + public YiqToRgb(INDArray input) { + addInputArgument(input); + } + + public YiqToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "yiq_to_rgb"; + } + + @Override + public String tensorflowName() { + return "YiqToRgb"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java new file mode 100644 index 000000000..4643ec3fe --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/YuvToRgb.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class YuvToRgb extends DynamicCustomOp { + public YuvToRgb(INDArray input) { + addInputArgument(input); + } + + public YuvToRgb(SameDiff sameDiff, SDVariable input) { + super(sameDiff, new SDVariable[]{input}); + } + + @Override + public String opName() { + return "yuv_to_rgb"; + } + + @Override + public String tensorflowName() { + return "YuvToRgb"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + int n = args().length; + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 66c68e3c4..42144bb97 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -6081,6 +6081,42 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(mE, mC); } + /* + Analog of this TF code: + a = tf.constant([], shape=[0,1]) + b = tf.constant([], shape=[1, 0]) + c = tf.matmul(a, b) + */ + @Test + public void testMatmul_Empty() { + val mA = Nd4j.create(0,1); + val mB = Nd4j.create(1,0); + val mC = Nd4j.create(0,0); + + val op = DynamicCustomOp.builder("matmul") + .addInputs(mA, mB) + .addOutputs(mC) + .build(); + + Nd4j.getExecutioner().exec(op); + assertEquals(Nd4j.create(0,0), mC); + } + + @Test + public void testMatmul_Empty1() { + val mA = Nd4j.create(1,0, 4); + val mB = Nd4j.create(1,4, 0); + val mC = Nd4j.create(1,0, 0); + + val op = DynamicCustomOp.builder("mmul") + .addInputs(mA, mB) + .addOutputs(mC) + .addIntegerArguments(0,0) + .build(); + + Nd4j.getExecutioner().exec(op); + assertEquals(Nd4j.create(1,0,0), mC); + } @Test public void testScalarSqueeze() { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index d6f367988..dea4b6a00 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.custom; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.apache.commons.lang3.ArrayUtils; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; @@ -1161,6 +1162,23 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(expectedY.shape(), y.shape()); } + @Test + public void testFusedBatchNormHalf() { + INDArray x = Nd4j.create(DataType.HALF, 1,2,3,4); + //INDArray scale = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f}); + //INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f}); + INDArray scale = Nd4j.create(DataType.HALF, 4); + INDArray offset = Nd4j.create(DataType.HALF, 4); + + INDArray y = Nd4j.createUninitialized(DataType.HALF, x.shape()); + INDArray batchMean = Nd4j.create(4); + INDArray batchVar = Nd4j.create(4); + + FusedBatchNorm op = new FusedBatchNorm(x, scale, offset, 0, 1, + y, batchMean, batchVar); + Nd4j.exec(op); + } + @Test public void testMatrixBandPart() { INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); @@ -1367,76 +1385,232 @@ public class CustomOpsTests extends BaseNd4jTest { @Test @Ignore public void testRgbToHsv() { - INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, - 3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, - 9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, - 2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, - 4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, - 9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, - 3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, - 2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, - 6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, - 2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, - 7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, - 9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, - 1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, - 9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, - 9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3); - INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, - 50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, - 59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, - 156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, - 115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3); + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, + 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, + 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, + 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, + 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, + 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, + 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, + 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f + }).reshape(5,4,3); + INDArray input = Nd4j.createFromArray(new float[]{ + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f + }).reshape(5,4,3); RgbToHsv op = new RgbToHsv(input); INDArray[] ret = Nd4j.exec(op); assertEquals(ret[0], expected); } // Exact copy of libnd4j test - @Ignore @Test public void testHsvToRgb() { - INDArray input = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, - 153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3); + INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f}).reshape(4,3); - INDArray expected = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, - 0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.8745098f, 180.f, 0.74871795f, 0.76470588f, - 77.6f, 0.49019608f, 0.6f, 260.74468085f, - 0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f, - 0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3); + INDArray expected = Nd4j.createFromArray(new float[]{0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f}).reshape(4,3); HsvToRgb op = new HsvToRgb(input); INDArray[] ret = Nd4j.exec(op); assertEquals(ret[0], expected); - } - @Ignore @Test public void testHsvToRgb_1() { /* Emulation of simple TF test: image = tf.random_uniform(shape = [1,1,3]) tf.image.hsv_to_rgb(image)*/ - INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}). + INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f}). reshape(1,1,3); HsvToRgb op = new HsvToRgb(image); INDArray[] ret = Nd4j.exec(op); - INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3); + System.out.println(ret[0].toStringFull()); + INDArray expected = Nd4j.createFromArray(new float[]{ 0.53442812f, 0.144007325f, 0.724374652f}).reshape(1,1,3); assertEquals(expected, ret[0]); } - @Ignore @Test public void testRgbToHsv_1() { /* Emulation of simple TF test: image = tf.random_uniform(shape = [1,2,3]) tf.image.rgb_to_hsv(image)*/ - INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f, - 0.2309f,0.7271f,0.1804f}).reshape(1,2,3); + INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f, + 0.230894327f, 0.727141261f, 0.180390716f }).reshape(2,3); RgbToHsv op = new RgbToHsv(image); INDArray[] ret = Nd4j.exec(op); - INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f, - 0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3); + INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f,0.095885336f,0.801197767f, + 0.317938268f,0.751917899f,0.727141261f}).reshape(2,3); assertEquals(expected, ret[0]); } + + @Test + public void testLu() { + INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f}) + .reshape(3,3); + Lu op = new Lu(input); + INDArray[] ret = Nd4j.exec(op); + + INDArray expected = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7f}).reshape(3,3); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYiq() { + INDArray image = Nd4j.createFromArray(new float[]{ + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , + 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, + 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, + 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, + 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, + 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f + }).reshape(5,4,3); + + RgbToYiq op = new RgbToYiq(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testYiqToRgb() { + INDArray image = Nd4j.createFromArray(new float[]{ + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, + 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, + 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, + 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, + -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, + 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, + 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, + 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, + -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, + 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, + 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, + 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, + 0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, + 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, + -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f + }).reshape(5,4,3); + + YiqToRgb op = new YiqToRgb(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToGrayscale() { + INDArray image = Nd4j.createFromArray(new float[]{ + 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, + 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, + -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f, + 2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, + 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, + 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f, + 2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, + 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, + 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, + 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, + 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f, + -2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, + -4.8125e+01f + }).reshape(5,4,3); + + INDArray expected = Nd4j.createFromArray(new float[]{ + -47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f, + -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, + -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, + 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f + }).reshape(5,4,1); + + RgbToGrayscale op = new RgbToGrayscale(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYuv() { + INDArray image = Nd4j.createFromArray(new float[]{ + 10f,50f,200f + }); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 55.14f , 71.2872001f, -39.6005542f + }); + + RgbToYuv op = new RgbToYuv(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testYuvToRgb() { + INDArray image = Nd4j.createFromArray(new float[]{ + 55.14f , 71.2872001f, -39.6005542f + }); + + INDArray expected = Nd4j.createFromArray(new float[]{ + 10f, 50f, 200f + }); + YuvToRgb op = new YuvToRgb(image); + INDArray[] ret = Nd4j.exec(op); + assertEquals(expected, ret[0]); + } + + @Test + public void testRgbToYiqEmpty() { + INDArray image = Nd4j.create(0,4,3); + RgbToYiq op = new RgbToYiq(image); + INDArray[] ret = Nd4j.exec(op); + assertArrayEquals(image.shape(), ret[0].shape()); + } }