Lu wrapper and tests fixes (#144)

* Tests fixed

* Lu added

* Test fixed

* Default timeout

* Tests timeouts fixed.

* TF import fix

* Timeouts added

* Timeout fixed.

* Test corrected

* rgb and yiq conversion ops added

* Converter ops added

* Header

* Yuv converters

* API added

* Empty test for matmul

* Explanation

* skip gemm/gemv on empty inputs

Signed-off-by: raver119 <raver119@gmail.com>

* Test added

* Correct test

* one more empty pass-through for mmul

Signed-off-by: raver119 <raver119@gmail.com>

* Cleanup

* Test added

* Test fixed

* Added missing mapping

* Added missing mapping

Co-authored-by: raver119 <raver119@gmail.com>
master
Alexander Stoyakin 2019-12-30 14:06:12 +02:00 committed by raver119
parent 9b329d2601
commit 010744ef9c
23 changed files with 716 additions and 46 deletions

View File

@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.base.MnistFetcher; import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.AfterClass; import org.junit.*;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest {
@ClassRule @ClassRule
public static TemporaryFolder testDir = new TemporaryFolder(); public static TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@BeforeClass @BeforeClass
public static void setup() throws Exception { public static void setup() throws Exception {

View File

@ -72,7 +72,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Rule @Rule
protected Timeout timeout = Timeout.seconds(300); public Timeout timeout = Timeout.seconds(300);
@Override @Override
public DataType getDataType(){ public DataType getDataType(){

View File

@ -71,7 +71,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
public TemporaryFolder temporaryFolder = new TemporaryFolder(); public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule @Rule
protected Timeout timeout = Timeout.seconds(300); public Timeout timeout = Timeout.seconds(300);
@Test @Test
public void testsBasic() throws Exception { public void testsBasic() throws Exception {

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.datasets.fetchers; package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.Timeout;
import java.io.File; import java.io.File;
@ -28,6 +30,9 @@ import static org.junit.Assert.assertTrue;
*/ */
public class SvhnDataFetcherTest extends BaseDL4JTest { public class SvhnDataFetcherTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Test @Test
public void testSvhnDataFetcher() throws Exception { public void testSvhnDataFetcher() throws Exception {
SvhnDataFetcher fetch = new SvhnDataFetcher(); SvhnDataFetcher fetch = new SvhnDataFetcher();

View File

@ -40,7 +40,7 @@ import static org.junit.Assert.*;
public class MultipleEpochsIteratorTest extends BaseDL4JTest { public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule @Rule
protected Timeout timeout = Timeout.seconds(300); public Timeout timeout = Timeout.seconds(300);
@Test @Test
public void testNextAndReset() throws Exception { public void testNextAndReset() throws Exception {

View File

@ -19,7 +19,9 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -33,6 +35,9 @@ import static org.junit.Assert.*;
@Slf4j @Slf4j
public class TestEmnistDataSetIterator extends BaseDL4JTest { public class TestEmnistDataSetIterator extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Override @Override
public DataType getDataType(){ public DataType getDataType(){
return DataType.FLOAT; return DataType.FLOAT;

View File

@ -22,6 +22,7 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.rules.TestName; import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
@ -36,6 +37,9 @@ import java.util.Properties;
@Slf4j @Slf4j
public class BaseDL4JTest { public class BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Rule @Rule
public TestName name = new TestName(); public TestName name = new TestName();

View File

@ -22,6 +22,7 @@ import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.rules.TestName; import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -37,6 +38,9 @@ import java.util.Properties;
@Slf4j @Slf4j
public class BaseDL4JTest { public class BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Rule @Rule
public TestName name = new TestName(); public TestName name = new TestName();

View File

@ -202,6 +202,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
if(C == nullptr) if(C == nullptr)
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const auto aType = A->dataType(); const auto aType = A->dataType();
const auto bType = B->dataType(); const auto bType = B->dataType();
const auto cType = C->dataType(); const auto cType = C->dataType();
@ -307,6 +310,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
if(Y == nullptr) if(Y == nullptr)
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
if (Y->isEmpty())
return Y;
const int incx = X->stridesOf()[xLenDim]; const int incx = X->stridesOf()[xLenDim];
const int incy = Y->stridesOf()[yLenDim]; const int incy = Y->stridesOf()[yLenDim];
@ -511,6 +517,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
C = new NDArray(outOrder, cExpectedShape, B->dataType()); C = new NDArray(outOrder, cExpectedShape, B->dataType());
} }
if (C->isEmpty())
return C;
const int cRank = C->rankOf(); const int cRank = C->rankOf();
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);

View File

@ -235,6 +235,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
if(C == nullptr) if(C == nullptr)
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first(); const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first();
const auto aType = A->dataType(); const auto aType = A->dataType();
@ -376,6 +379,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
if(Y == nullptr) if(Y == nullptr)
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext()); Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
if (Y->isEmpty())
return Y;
const int incx = X->strideAt(xLenDim); const int incx = X->strideAt(xLenDim);
const int incy = Y->strideAt(yLenDim); const int incy = Y->strideAt(yLenDim);
@ -634,6 +640,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
else else
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext()); C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const int cRank = C->rankOf(); const int cRank = C->rankOf();
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1); const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);

View File

@ -236,6 +236,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
throw std::invalid_argument(""); throw std::invalid_argument("");
} }
if (z->isEmpty())
return;
NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z); NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);
if((transX && xRank > 1) || (transY && yRank > 1)) { if((transX && xRank > 1) || (transY && yRank > 1)) {

View File

@ -323,4 +323,35 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
delete result0; delete result0;
delete result1; delete result1;
} }
TEST_F(EmptyTests, test_empty_matmul_1) {
auto x = NDArrayFactory::create<float>('c', {0, 1});
auto y = NDArrayFactory::create<float>('c', {1, 0});
auto e = NDArrayFactory::create<float>('c', {0, 0});
nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(EmptyTests, test_empty_matmul_2) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
auto y = NDArrayFactory::create<float>('c', {1, 4, 0});
auto e = NDArrayFactory::create<float>('c', {1, 0, 0});
nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -274,3 +274,7 @@ To build libnd4j with MKL:
Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j. Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j.
Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'. Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'.
float16_nhcw
float16_nhwc

View File

@ -138,4 +138,48 @@ public class SDImage extends SDOps {
SDVariable out = new HsvToRgb(sd, input).outputVariable(); SDVariable out = new HsvToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name); return updateVariableNameAndReference(out, name);
} }
/**
* Converting array from RGB to YIQ format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable rgbToYiq(String name, @NonNull SDVariable input) {
SDVariable out = new RgbToYiq(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting image from YIQ to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable yiqToRgb(String name, @NonNull SDVariable input) {
SDVariable out = new YiqToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to YUV format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable rgbToYuv(String name, @NonNull SDVariable input) {
SDVariable out = new RgbToYuv(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting image from YUV to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable yuvToRgb(String name, @NonNull SDVariable input) {
SDVariable out = new YuvToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
} }

View File

@ -593,6 +593,11 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class, org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.HsvToRgb.class, org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
org.nd4j.linalg.api.ops.custom.RgbToHsv.class, org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
org.nd4j.linalg.api.ops.custom.RgbToYiq.class,
org.nd4j.linalg.api.ops.custom.RgbToGrayscale.class,
org.nd4j.linalg.api.ops.custom.YiqToRgb.class,
org.nd4j.linalg.api.ops.custom.RgbToYuv.class,
org.nd4j.linalg.api.ops.custom.YuvToRgb.class,
org.nd4j.linalg.api.ops.custom.BitCast.class, org.nd4j.linalg.api.ops.custom.BitCast.class,
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class, org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
org.nd4j.linalg.api.ops.custom.DivideNoNan.class, org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
@ -609,7 +614,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.ToggleBits.class, org.nd4j.linalg.api.ops.custom.ToggleBits.class,
org.nd4j.linalg.api.ops.custom.Igamma.class, org.nd4j.linalg.api.ops.custom.Igamma.class,
org.nd4j.linalg.api.ops.custom.Igammac.class, org.nd4j.linalg.api.ops.custom.Igammac.class,
org.nd4j.linalg.api.ops.custom.Digamma.class org.nd4j.linalg.api.ops.custom.Digamma.class,
org.nd4j.linalg.api.ops.custom.Lu.class
); );
static { static {

View File

@ -0,0 +1,69 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class Lu extends DynamicCustomOp {
private DataType indexDataType;
public Lu(INDArray input) {
addInputArgument(input);
}
public Lu(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "lu";
}
@Override
public String tensorflowName() {
return "Lu";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if (attributesForNode.containsKey("output_idx_type")){
indexDataType = TFGraphMapper.convertType(attributesForNode.get("output_idx_type").getType());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Arrays.asList(inputDataTypes.get(0), indexDataType);
}
}

View File

@ -0,0 +1,44 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@NoArgsConstructor
public class RgbToGrayscale extends DynamicCustomOp {
public RgbToGrayscale(INDArray image) {
addInputArgument(image);
}
public RgbToGrayscale(SameDiff sameDiff, SDVariable image) {
super(sameDiff, new SDVariable[]{image});
}
@Override
public String opName() {
return "rgb_to_grs";
}
@Override
public String tensorflowName() {
return "RgbToGrs";
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RgbToYiq extends DynamicCustomOp {
public RgbToYiq(INDArray input) {
addInputArgument(input);
}
public RgbToYiq(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "rgb_to_yiq";
}
@Override
public String tensorflowName() {
return "RgbToYiq";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RgbToYuv extends DynamicCustomOp {
public RgbToYuv(INDArray input) {
addInputArgument(input);
}
public RgbToYuv(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "rgb_to_yuv";
}
@Override
public String tensorflowName() {
return "RgbToYuv";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,55 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class YiqToRgb extends DynamicCustomOp {
public YiqToRgb(INDArray input) {
addInputArgument(input);
}
public YiqToRgb(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "yiq_to_rgb";
}
@Override
public String tensorflowName() {
return "YiqToRgb";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class YuvToRgb extends DynamicCustomOp {
public YuvToRgb(INDArray input) {
addInputArgument(input);
}
public YuvToRgb(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "yuv_to_rgb";
}
@Override
public String tensorflowName() {
return "YuvToRgb";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -6081,6 +6081,42 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(mE, mC); assertEquals(mE, mC);
} }
/*
Analog of this TF code:
a = tf.constant([], shape=[0,1])
b = tf.constant([], shape=[1, 0])
c = tf.matmul(a, b)
*/
@Test
public void testMatmul_Empty() {
val mA = Nd4j.create(0,1);
val mB = Nd4j.create(1,0);
val mC = Nd4j.create(0,0);
val op = DynamicCustomOp.builder("matmul")
.addInputs(mA, mB)
.addOutputs(mC)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.create(0,0), mC);
}
@Test
public void testMatmul_Empty1() {
val mA = Nd4j.create(1,0, 4);
val mB = Nd4j.create(1,4, 0);
val mC = Nd4j.create(1,0, 0);
val op = DynamicCustomOp.builder("mmul")
.addInputs(mA, mB)
.addOutputs(mC)
.addIntegerArguments(0,0)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.create(1,0,0), mC);
}
@Test @Test
public void testScalarSqueeze() { public void testScalarSqueeze() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.custom;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTest;
@ -1161,6 +1162,23 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(expectedY.shape(), y.shape()); assertArrayEquals(expectedY.shape(), y.shape());
} }
@Test
public void testFusedBatchNormHalf() {
INDArray x = Nd4j.create(DataType.HALF, 1,2,3,4);
//INDArray scale = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
//INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
INDArray scale = Nd4j.create(DataType.HALF, 4);
INDArray offset = Nd4j.create(DataType.HALF, 4);
INDArray y = Nd4j.createUninitialized(DataType.HALF, x.shape());
INDArray batchMean = Nd4j.create(4);
INDArray batchVar = Nd4j.create(4);
FusedBatchNorm op = new FusedBatchNorm(x, scale, offset, 0, 1,
y, batchMean, batchVar);
Nd4j.exec(op);
}
@Test @Test
public void testMatrixBandPart() { public void testMatrixBandPart() {
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3); INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
@ -1367,76 +1385,232 @@ public class CustomOpsTests extends BaseNd4jTest {
@Test @Test
@Ignore @Ignore
public void testRgbToHsv() { public void testRgbToHsv() {
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f, INDArray expected = Nd4j.createFromArray(new float[]{
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f, 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f, 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f, 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f, 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f, 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f, 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f, 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f, 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f, 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f, 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f, 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f, 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f, }).reshape(5,4,3);
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3); INDArray input = Nd4j.createFromArray(new float[]{
INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f, 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f, 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f, 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f, 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3); 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f,
0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f,
0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f,
0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f,
0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f,
0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f,
0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f,
0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f
}).reshape(5,4,3);
RgbToHsv op = new RgbToHsv(input); RgbToHsv op = new RgbToHsv(input);
INDArray[] ret = Nd4j.exec(op); INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected); assertEquals(ret[0], expected);
} }
// Exact copy of libnd4j test // Exact copy of libnd4j test
@Ignore
@Test @Test
public void testHsvToRgb() { public void testHsvToRgb() {
INDArray input = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f, INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3); 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f,
0.332347751f, 0.111181192f}).reshape(4,3);
INDArray expected = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f, INDArray expected = Nd4j.createFromArray(new float[]{0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.8745098f, 180.f, 0.74871795f, 0.76470588f, 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f,
77.6f, 0.49019608f, 0.6f, 260.74468085f, 0.111181192f, 0.074230373f}).reshape(4,3);
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
HsvToRgb op = new HsvToRgb(input); HsvToRgb op = new HsvToRgb(input);
INDArray[] ret = Nd4j.exec(op); INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected); assertEquals(ret[0], expected);
} }
@Ignore
@Test @Test
public void testHsvToRgb_1() { public void testHsvToRgb_1() {
/* Emulation of simple TF test: /* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,1,3]) image = tf.random_uniform(shape = [1,1,3])
tf.image.hsv_to_rgb(image)*/ tf.image.hsv_to_rgb(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}). INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f}).
reshape(1,1,3); reshape(1,1,3);
HsvToRgb op = new HsvToRgb(image); HsvToRgb op = new HsvToRgb(image);
INDArray[] ret = Nd4j.exec(op); INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3); System.out.println(ret[0].toStringFull());
INDArray expected = Nd4j.createFromArray(new float[]{ 0.53442812f, 0.144007325f, 0.724374652f}).reshape(1,1,3);
assertEquals(expected, ret[0]); assertEquals(expected, ret[0]);
} }
@Ignore
@Test @Test
public void testRgbToHsv_1() { public void testRgbToHsv_1() {
/* Emulation of simple TF test: /* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,2,3]) image = tf.random_uniform(shape = [1,2,3])
tf.image.rgb_to_hsv(image)*/ tf.image.rgb_to_hsv(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f, INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f,
0.2309f,0.7271f,0.1804f}).reshape(1,2,3); 0.230894327f, 0.727141261f, 0.180390716f }).reshape(2,3);
RgbToHsv op = new RgbToHsv(image); RgbToHsv op = new RgbToHsv(image);
INDArray[] ret = Nd4j.exec(op); INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f, INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f,0.095885336f,0.801197767f,
0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3); 0.317938268f,0.751917899f,0.727141261f}).reshape(2,3);
assertEquals(expected, ret[0]); assertEquals(expected, ret[0]);
} }
@Test
public void testLu() {
INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f})
.reshape(3,3);
Lu op = new Lu(input);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7f}).reshape(3,3);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYiq() {
INDArray image = Nd4j.createFromArray(new float[]{
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
-0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
-0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
-0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
-0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
}).reshape(5,4,3);
RgbToYiq op = new RgbToYiq(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testYiqToRgb() {
INDArray image = Nd4j.createFromArray(new float[]{
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
-0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f,
0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f,
0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f,
0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f,
0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f,
-0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f,
-0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f,
0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f,
0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f,
0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f,
-0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f,
0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f,
0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f,
0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f,
0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f,
1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f,
0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f,
1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f,
-0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f
}).reshape(5,4,3);
YiqToRgb op = new YiqToRgb(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToGrayscale() {
INDArray image = Nd4j.createFromArray(new float[]{
1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f,
9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f,
-5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,
2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f,
1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f,
1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,
2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f,
3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f,
5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f,
1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f,
5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,
-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f,
-4.8125e+01f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f,
-43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f,
-25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f,
15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f
}).reshape(5,4,1);
RgbToGrayscale op = new RgbToGrayscale(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYuv() {
INDArray image = Nd4j.createFromArray(new float[]{
10f,50f,200f
});
INDArray expected = Nd4j.createFromArray(new float[]{
55.14f , 71.2872001f, -39.6005542f
});
RgbToYuv op = new RgbToYuv(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testYuvToRgb() {
INDArray image = Nd4j.createFromArray(new float[]{
55.14f , 71.2872001f, -39.6005542f
});
INDArray expected = Nd4j.createFromArray(new float[]{
10f, 50f, 200f
});
YuvToRgb op = new YuvToRgb(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYiqEmpty() {
INDArray image = Nd4j.create(0,4,3);
RgbToYiq op = new RgbToYiq(image);
INDArray[] ret = Nd4j.exec(op);
assertArrayEquals(image.shape(), ret[0].shape());
}
} }