Lu wrapper and tests fixes (#144)
* Tests fixed * Lu added * Test fixed * Default timeout * Tests timeouts fixed. * TF import fix * Timeouts added * Timeout fixed. * Test corrected * rgb and yiq conversion ops added * Converter ops added * Header * Yuv converters * API added * Empty test for matmul * Explanation * skip gemm/gemv on empty inputs Signed-off-by: raver119 <raver119@gmail.com> * Test added * Correct test * one more empty pass-through for mmul Signed-off-by: raver119 <raver119@gmail.com> * Cleanup * Test added * Test fixed * Added missing mapping * Added missing mapping Co-authored-by: raver119 <raver119@gmail.com>master
parent
9b329d2601
commit
010744ef9c
|
@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.base.MnistFetcher;
|
import org.deeplearning4j.base.MnistFetcher;
|
||||||
import org.deeplearning4j.common.resources.DL4JResources;
|
import org.deeplearning4j.common.resources.DL4JResources;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.junit.AfterClass;
|
import org.junit.*;
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.ClassRule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@ClassRule
|
@ClassRule
|
||||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
public static TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
@Rule
|
||||||
|
public Timeout timeout = Timeout.seconds(300);
|
||||||
|
|
||||||
@BeforeClass
|
@BeforeClass
|
||||||
public static void setup() throws Exception {
|
public static void setup() throws Exception {
|
||||||
|
|
|
@ -72,7 +72,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
|
||||||
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
protected Timeout timeout = Timeout.seconds(300);
|
public Timeout timeout = Timeout.seconds(300);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
|
|
|
@ -71,7 +71,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||||
public TemporaryFolder temporaryFolder = new TemporaryFolder();
|
public TemporaryFolder temporaryFolder = new TemporaryFolder();
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
protected Timeout timeout = Timeout.seconds(300);
|
public Timeout timeout = Timeout.seconds(300);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testsBasic() throws Exception {
|
public void testsBasic() throws Exception {
|
||||||
|
|
|
@ -17,7 +17,9 @@
|
||||||
package org.deeplearning4j.datasets.fetchers;
|
package org.deeplearning4j.datasets.fetchers;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
|
||||||
|
@ -28,6 +30,9 @@ import static org.junit.Assert.assertTrue;
|
||||||
*/
|
*/
|
||||||
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
public class SvhnDataFetcherTest extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public Timeout timeout = Timeout.seconds(600);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSvhnDataFetcher() throws Exception {
|
public void testSvhnDataFetcher() throws Exception {
|
||||||
SvhnDataFetcher fetch = new SvhnDataFetcher();
|
SvhnDataFetcher fetch = new SvhnDataFetcher();
|
||||||
|
|
|
@ -40,7 +40,7 @@ import static org.junit.Assert.*;
|
||||||
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
protected Timeout timeout = Timeout.seconds(300);
|
public Timeout timeout = Timeout.seconds(300);
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testNextAndReset() throws Exception {
|
public void testNextAndReset() throws Exception {
|
||||||
|
|
|
@ -19,7 +19,9 @@ package org.deeplearning4j.datasets.iterator;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
|
||||||
|
import org.junit.Rule;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
|
@ -33,6 +35,9 @@ import static org.junit.Assert.*;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestEmnistDataSetIterator extends BaseDL4JTest {
|
public class TestEmnistDataSetIterator extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public Timeout timeout = Timeout.seconds(600);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataType getDataType(){
|
public DataType getDataType(){
|
||||||
return DataType.FLOAT;
|
return DataType.FLOAT;
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TestName;
|
import org.junit.rules.TestName;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
|
@ -36,6 +37,9 @@ import java.util.Properties;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class BaseDL4JTest {
|
public class BaseDL4JTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public Timeout timeout = Timeout.seconds(600);
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TestName name = new TestName();
|
public TestName name = new TestName();
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.junit.After;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TestName;
|
import org.junit.rules.TestName;
|
||||||
|
import org.junit.rules.Timeout;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
|
@ -37,6 +38,9 @@ import java.util.Properties;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class BaseDL4JTest {
|
public class BaseDL4JTest {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public Timeout timeout = Timeout.seconds(600);
|
||||||
|
|
||||||
@Rule
|
@Rule
|
||||||
public TestName name = new TestName();
|
public TestName name = new TestName();
|
||||||
|
|
||||||
|
|
|
@ -202,6 +202,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
|
||||||
if(C == nullptr)
|
if(C == nullptr)
|
||||||
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
||||||
|
|
||||||
|
if (C->isEmpty())
|
||||||
|
return C;
|
||||||
|
|
||||||
const auto aType = A->dataType();
|
const auto aType = A->dataType();
|
||||||
const auto bType = B->dataType();
|
const auto bType = B->dataType();
|
||||||
const auto cType = C->dataType();
|
const auto cType = C->dataType();
|
||||||
|
@ -307,6 +310,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
||||||
if(Y == nullptr)
|
if(Y == nullptr)
|
||||||
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
|
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
|
||||||
|
|
||||||
|
if (Y->isEmpty())
|
||||||
|
return Y;
|
||||||
|
|
||||||
const int incx = X->stridesOf()[xLenDim];
|
const int incx = X->stridesOf()[xLenDim];
|
||||||
const int incy = Y->stridesOf()[yLenDim];
|
const int incy = Y->stridesOf()[yLenDim];
|
||||||
|
|
||||||
|
@ -511,6 +517,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
|
||||||
C = new NDArray(outOrder, cExpectedShape, B->dataType());
|
C = new NDArray(outOrder, cExpectedShape, B->dataType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (C->isEmpty())
|
||||||
|
return C;
|
||||||
|
|
||||||
const int cRank = C->rankOf();
|
const int cRank = C->rankOf();
|
||||||
|
|
||||||
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);
|
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);
|
||||||
|
|
|
@ -235,6 +235,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
|
||||||
if(C == nullptr)
|
if(C == nullptr)
|
||||||
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
||||||
|
|
||||||
|
if (C->isEmpty())
|
||||||
|
return C;
|
||||||
|
|
||||||
const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first();
|
const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first();
|
||||||
|
|
||||||
const auto aType = A->dataType();
|
const auto aType = A->dataType();
|
||||||
|
@ -376,6 +379,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
|
||||||
if(Y == nullptr)
|
if(Y == nullptr)
|
||||||
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
|
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
|
||||||
|
|
||||||
|
if (Y->isEmpty())
|
||||||
|
return Y;
|
||||||
|
|
||||||
const int incx = X->strideAt(xLenDim);
|
const int incx = X->strideAt(xLenDim);
|
||||||
const int incy = Y->strideAt(yLenDim);
|
const int incy = Y->strideAt(yLenDim);
|
||||||
|
|
||||||
|
@ -634,6 +640,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
|
||||||
else
|
else
|
||||||
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
|
||||||
|
|
||||||
|
if (C->isEmpty())
|
||||||
|
return C;
|
||||||
|
|
||||||
const int cRank = C->rankOf();
|
const int cRank = C->rankOf();
|
||||||
|
|
||||||
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);
|
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);
|
||||||
|
|
|
@ -236,6 +236,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
|
||||||
throw std::invalid_argument("");
|
throw std::invalid_argument("");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (z->isEmpty())
|
||||||
|
return;
|
||||||
|
|
||||||
NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);
|
NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);
|
||||||
|
|
||||||
if((transX && xRank > 1) || (transY && yRank > 1)) {
|
if((transX && xRank > 1) || (transY && yRank > 1)) {
|
||||||
|
|
|
@ -324,3 +324,34 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
|
||||||
delete result0;
|
delete result0;
|
||||||
delete result1;
|
delete result1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_matmul_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::matmul op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_matmul_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {1, 4, 0});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {1, 0, 0});
|
||||||
|
|
||||||
|
nd4j::ops::matmul op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
|
@ -274,3 +274,7 @@ To build libnd4j with MKL:
|
||||||
Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j.
|
Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j.
|
||||||
|
|
||||||
Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'.
|
Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'.
|
||||||
|
|
||||||
|
|
||||||
|
float16_nhcw
|
||||||
|
float16_nhwc
|
|
@ -138,4 +138,48 @@ public class SDImage extends SDOps {
|
||||||
SDVariable out = new HsvToRgb(sd, input).outputVariable();
|
SDVariable out = new HsvToRgb(sd, input).outputVariable();
|
||||||
return updateVariableNameAndReference(out, name);
|
return updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converting array from RGB to YIQ format
|
||||||
|
* @param name name
|
||||||
|
* @param input 3D image
|
||||||
|
* @return 3D image
|
||||||
|
*/
|
||||||
|
public SDVariable rgbToYiq(String name, @NonNull SDVariable input) {
|
||||||
|
SDVariable out = new RgbToYiq(sd, input).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converting image from YIQ to RGB format
|
||||||
|
* @param name name
|
||||||
|
* @param input 3D image
|
||||||
|
* @return 3D image
|
||||||
|
*/
|
||||||
|
public SDVariable yiqToRgb(String name, @NonNull SDVariable input) {
|
||||||
|
SDVariable out = new YiqToRgb(sd, input).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converting array from RGB to YUV format
|
||||||
|
* @param name name
|
||||||
|
* @param input 3D image
|
||||||
|
* @return 3D image
|
||||||
|
*/
|
||||||
|
public SDVariable rgbToYuv(String name, @NonNull SDVariable input) {
|
||||||
|
SDVariable out = new RgbToYuv(sd, input).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converting image from YUV to RGB format
|
||||||
|
* @param name name
|
||||||
|
* @param input 3D image
|
||||||
|
* @return 3D image
|
||||||
|
*/
|
||||||
|
public SDVariable yuvToRgb(String name, @NonNull SDVariable input) {
|
||||||
|
SDVariable out = new YuvToRgb(sd, input).outputVariable();
|
||||||
|
return updateVariableNameAndReference(out, name);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -593,6 +593,11 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
|
||||||
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
|
||||||
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
|
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.RgbToYiq.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.RgbToGrayscale.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.YiqToRgb.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.RgbToYuv.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.YuvToRgb.class,
|
||||||
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
org.nd4j.linalg.api.ops.custom.BitCast.class,
|
||||||
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
|
||||||
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
|
||||||
|
@ -609,7 +614,8 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
|
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Igamma.class,
|
org.nd4j.linalg.api.ops.custom.Igamma.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
org.nd4j.linalg.api.ops.custom.Igammac.class,
|
||||||
org.nd4j.linalg.api.ops.custom.Digamma.class
|
org.nd4j.linalg.api.ops.custom.Digamma.class,
|
||||||
|
org.nd4j.linalg.api.ops.custom.Lu.class
|
||||||
);
|
);
|
||||||
|
|
||||||
static {
|
static {
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class Lu extends DynamicCustomOp {
|
||||||
|
private DataType indexDataType;
|
||||||
|
|
||||||
|
public Lu(INDArray input) {
|
||||||
|
addInputArgument(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Lu(SameDiff sameDiff, SDVariable input) {
|
||||||
|
super(sameDiff, new SDVariable[]{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "lu";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "Lu";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
if (attributesForNode.containsKey("output_idx_type")){
|
||||||
|
indexDataType = TFGraphMapper.convertType(attributesForNode.get("output_idx_type").getType());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Arrays.asList(inputDataTypes.get(0), indexDataType);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,44 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class RgbToGrayscale extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public RgbToGrayscale(INDArray image) {
|
||||||
|
addInputArgument(image);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RgbToGrayscale(SameDiff sameDiff, SDVariable image) {
|
||||||
|
super(sameDiff, new SDVariable[]{image});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "rgb_to_grs";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "RgbToGrs";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class RgbToYiq extends DynamicCustomOp {
|
||||||
|
|
||||||
|
public RgbToYiq(INDArray input) {
|
||||||
|
addInputArgument(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RgbToYiq(SameDiff sameDiff, SDVariable input) {
|
||||||
|
super(sameDiff, new SDVariable[]{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "rgb_to_yiq";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "RgbToYiq";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class RgbToYuv extends DynamicCustomOp {
|
||||||
|
public RgbToYuv(INDArray input) {
|
||||||
|
addInputArgument(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public RgbToYuv(SameDiff sameDiff, SDVariable input) {
|
||||||
|
super(sameDiff, new SDVariable[]{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "rgb_to_yuv";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "RgbToYuv";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class YiqToRgb extends DynamicCustomOp {
|
||||||
|
public YiqToRgb(INDArray input) {
|
||||||
|
addInputArgument(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public YiqToRgb(SameDiff sameDiff, SDVariable input) {
|
||||||
|
super(sameDiff, new SDVariable[]{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "yiq_to_rgb";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "YiqToRgb";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
/* ******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class YuvToRgb extends DynamicCustomOp {
|
||||||
|
public YuvToRgb(INDArray input) {
|
||||||
|
addInputArgument(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
public YuvToRgb(SameDiff sameDiff, SDVariable input) {
|
||||||
|
super(sameDiff, new SDVariable[]{input});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "yuv_to_rgb";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
return "YuvToRgb";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
int n = args().length;
|
||||||
|
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||||
|
return Collections.singletonList(inputDataTypes.get(0));
|
||||||
|
}
|
||||||
|
}
|
|
@ -6081,6 +6081,42 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
assertEquals(mE, mC);
|
assertEquals(mE, mC);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Analog of this TF code:
|
||||||
|
a = tf.constant([], shape=[0,1])
|
||||||
|
b = tf.constant([], shape=[1, 0])
|
||||||
|
c = tf.matmul(a, b)
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testMatmul_Empty() {
|
||||||
|
val mA = Nd4j.create(0,1);
|
||||||
|
val mB = Nd4j.create(1,0);
|
||||||
|
val mC = Nd4j.create(0,0);
|
||||||
|
|
||||||
|
val op = DynamicCustomOp.builder("matmul")
|
||||||
|
.addInputs(mA, mB)
|
||||||
|
.addOutputs(mC)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
assertEquals(Nd4j.create(0,0), mC);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMatmul_Empty1() {
|
||||||
|
val mA = Nd4j.create(1,0, 4);
|
||||||
|
val mB = Nd4j.create(1,4, 0);
|
||||||
|
val mC = Nd4j.create(1,0, 0);
|
||||||
|
|
||||||
|
val op = DynamicCustomOp.builder("mmul")
|
||||||
|
.addInputs(mA, mB)
|
||||||
|
.addOutputs(mC)
|
||||||
|
.addIntegerArguments(0,0)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
assertEquals(Nd4j.create(1,0,0), mC);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testScalarSqueeze() {
|
public void testScalarSqueeze() {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.nd4j.linalg.custom;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
|
@ -1161,6 +1162,23 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertArrayEquals(expectedY.shape(), y.shape());
|
assertArrayEquals(expectedY.shape(), y.shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testFusedBatchNormHalf() {
|
||||||
|
INDArray x = Nd4j.create(DataType.HALF, 1,2,3,4);
|
||||||
|
//INDArray scale = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
|
||||||
|
//INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
|
||||||
|
INDArray scale = Nd4j.create(DataType.HALF, 4);
|
||||||
|
INDArray offset = Nd4j.create(DataType.HALF, 4);
|
||||||
|
|
||||||
|
INDArray y = Nd4j.createUninitialized(DataType.HALF, x.shape());
|
||||||
|
INDArray batchMean = Nd4j.create(4);
|
||||||
|
INDArray batchVar = Nd4j.create(4);
|
||||||
|
|
||||||
|
FusedBatchNorm op = new FusedBatchNorm(x, scale, offset, 0, 1,
|
||||||
|
y, batchMean, batchVar);
|
||||||
|
Nd4j.exec(op);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMatrixBandPart() {
|
public void testMatrixBandPart() {
|
||||||
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
|
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
|
||||||
|
@ -1367,76 +1385,232 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testRgbToHsv() {
|
public void testRgbToHsv() {
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
|
0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
|
||||||
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
|
0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
|
||||||
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
|
0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f,
|
||||||
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
|
0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f,
|
||||||
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
|
0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f,
|
||||||
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
|
0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f,
|
||||||
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
|
0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f,
|
||||||
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
|
0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f,
|
||||||
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
|
0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f,
|
||||||
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
|
0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f,
|
||||||
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
|
0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f,
|
||||||
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
|
0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f
|
||||||
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
|
}).reshape(5,4,3);
|
||||||
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3);
|
INDArray input = Nd4j.createFromArray(new float[]{
|
||||||
INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
|
0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
|
||||||
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
|
0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
|
||||||
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
|
0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f,
|
||||||
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
|
0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f,
|
||||||
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3);
|
0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f,
|
||||||
|
0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f,
|
||||||
|
0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f,
|
||||||
|
0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f,
|
||||||
|
0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f,
|
||||||
|
0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f,
|
||||||
|
0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f,
|
||||||
|
0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f
|
||||||
|
}).reshape(5,4,3);
|
||||||
RgbToHsv op = new RgbToHsv(input);
|
RgbToHsv op = new RgbToHsv(input);
|
||||||
INDArray[] ret = Nd4j.exec(op);
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
assertEquals(ret[0], expected);
|
assertEquals(ret[0], expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exact copy of libnd4j test
|
// Exact copy of libnd4j test
|
||||||
@Ignore
|
|
||||||
@Test
|
@Test
|
||||||
public void testHsvToRgb() {
|
public void testHsvToRgb() {
|
||||||
INDArray input = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
|
INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f,
|
||||||
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
|
0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f,
|
||||||
|
0.332347751f, 0.111181192f}).reshape(4,3);
|
||||||
|
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
|
INDArray expected = Nd4j.createFromArray(new float[]{0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f,
|
||||||
0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.8745098f, 180.f, 0.74871795f, 0.76470588f,
|
0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f,
|
||||||
77.6f, 0.49019608f, 0.6f, 260.74468085f,
|
0.111181192f, 0.074230373f}).reshape(4,3);
|
||||||
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
|
|
||||||
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
|
|
||||||
|
|
||||||
HsvToRgb op = new HsvToRgb(input);
|
HsvToRgb op = new HsvToRgb(input);
|
||||||
INDArray[] ret = Nd4j.exec(op);
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
assertEquals(ret[0], expected);
|
assertEquals(ret[0], expected);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
|
||||||
@Test
|
@Test
|
||||||
public void testHsvToRgb_1() {
|
public void testHsvToRgb_1() {
|
||||||
/* Emulation of simple TF test:
|
/* Emulation of simple TF test:
|
||||||
image = tf.random_uniform(shape = [1,1,3])
|
image = tf.random_uniform(shape = [1,1,3])
|
||||||
tf.image.hsv_to_rgb(image)*/
|
tf.image.hsv_to_rgb(image)*/
|
||||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}).
|
INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f}).
|
||||||
reshape(1,1,3);
|
reshape(1,1,3);
|
||||||
HsvToRgb op = new HsvToRgb(image);
|
HsvToRgb op = new HsvToRgb(image);
|
||||||
INDArray[] ret = Nd4j.exec(op);
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3);
|
System.out.println(ret[0].toStringFull());
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{ 0.53442812f, 0.144007325f, 0.724374652f}).reshape(1,1,3);
|
||||||
assertEquals(expected, ret[0]);
|
assertEquals(expected, ret[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Ignore
|
|
||||||
@Test
|
@Test
|
||||||
public void testRgbToHsv_1() {
|
public void testRgbToHsv_1() {
|
||||||
/* Emulation of simple TF test:
|
/* Emulation of simple TF test:
|
||||||
image = tf.random_uniform(shape = [1,2,3])
|
image = tf.random_uniform(shape = [1,2,3])
|
||||||
tf.image.rgb_to_hsv(image)*/
|
tf.image.rgb_to_hsv(image)*/
|
||||||
INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,
|
INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f,
|
||||||
0.2309f,0.7271f,0.1804f}).reshape(1,2,3);
|
0.230894327f, 0.727141261f, 0.180390716f }).reshape(2,3);
|
||||||
RgbToHsv op = new RgbToHsv(image);
|
RgbToHsv op = new RgbToHsv(image);
|
||||||
INDArray[] ret = Nd4j.exec(op);
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f,
|
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f,0.095885336f,0.801197767f,
|
||||||
0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3);
|
0.317938268f,0.751917899f,0.727141261f}).reshape(2,3);
|
||||||
assertEquals(expected, ret[0]);
|
assertEquals(expected, ret[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLu() {
|
||||||
|
INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f})
|
||||||
|
.reshape(3,3);
|
||||||
|
Lu op = new Lu(input);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7f}).reshape(3,3);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRgbToYiq() {
|
||||||
|
INDArray image = Nd4j.createFromArray(new float[]{
|
||||||
|
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
|
||||||
|
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
|
||||||
|
0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
|
||||||
|
0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
|
||||||
|
0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
|
||||||
|
0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
|
||||||
|
0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
|
||||||
|
0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
|
||||||
|
0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
|
||||||
|
0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
|
||||||
|
0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
|
||||||
|
0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
|
||||||
|
}).reshape(5,4,3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
|
||||||
|
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
|
||||||
|
-0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
|
||||||
|
0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
|
||||||
|
0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
|
||||||
|
0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
|
||||||
|
0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
|
||||||
|
-0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
|
||||||
|
0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
|
||||||
|
0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
|
||||||
|
-0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
|
||||||
|
-0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
|
||||||
|
}).reshape(5,4,3);
|
||||||
|
|
||||||
|
RgbToYiq op = new RgbToYiq(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testYiqToRgb() {
|
||||||
|
INDArray image = Nd4j.createFromArray(new float[]{
|
||||||
|
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
|
||||||
|
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
|
||||||
|
-0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f,
|
||||||
|
0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f,
|
||||||
|
0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f,
|
||||||
|
0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f,
|
||||||
|
0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f,
|
||||||
|
-0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f,
|
||||||
|
-0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f,
|
||||||
|
0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f,
|
||||||
|
0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f,
|
||||||
|
0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f
|
||||||
|
}).reshape(5,4,3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
|
||||||
|
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
|
||||||
|
0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f,
|
||||||
|
-0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f,
|
||||||
|
0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f,
|
||||||
|
0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f,
|
||||||
|
0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f,
|
||||||
|
0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f,
|
||||||
|
1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f,
|
||||||
|
0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f,
|
||||||
|
1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f,
|
||||||
|
-0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f
|
||||||
|
}).reshape(5,4,3);
|
||||||
|
|
||||||
|
YiqToRgb op = new YiqToRgb(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRgbToGrayscale() {
|
||||||
|
INDArray image = Nd4j.createFromArray(new float[]{
|
||||||
|
1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f,
|
||||||
|
9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f,
|
||||||
|
-5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,
|
||||||
|
2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f,
|
||||||
|
1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f,
|
||||||
|
1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,
|
||||||
|
2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f,
|
||||||
|
3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f,
|
||||||
|
5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f,
|
||||||
|
1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f,
|
||||||
|
5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,
|
||||||
|
-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f,
|
||||||
|
-4.8125e+01f
|
||||||
|
}).reshape(5,4,3);
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f,
|
||||||
|
-43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f,
|
||||||
|
-25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f,
|
||||||
|
15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f
|
||||||
|
}).reshape(5,4,1);
|
||||||
|
|
||||||
|
RgbToGrayscale op = new RgbToGrayscale(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRgbToYuv() {
|
||||||
|
INDArray image = Nd4j.createFromArray(new float[]{
|
||||||
|
10f,50f,200f
|
||||||
|
});
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
55.14f , 71.2872001f, -39.6005542f
|
||||||
|
});
|
||||||
|
|
||||||
|
RgbToYuv op = new RgbToYuv(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testYuvToRgb() {
|
||||||
|
INDArray image = Nd4j.createFromArray(new float[]{
|
||||||
|
55.14f , 71.2872001f, -39.6005542f
|
||||||
|
});
|
||||||
|
|
||||||
|
INDArray expected = Nd4j.createFromArray(new float[]{
|
||||||
|
10f, 50f, 200f
|
||||||
|
});
|
||||||
|
YuvToRgb op = new YuvToRgb(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertEquals(expected, ret[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRgbToYiqEmpty() {
|
||||||
|
INDArray image = Nd4j.create(0,4,3);
|
||||||
|
RgbToYiq op = new RgbToYiq(image);
|
||||||
|
INDArray[] ret = Nd4j.exec(op);
|
||||||
|
assertArrayEquals(image.shape(), ret[0].shape());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue