Lu wrapper and tests fixes (#144)

* Tests fixed

* Lu added

* Test fixed

* Default timeout

* Tests timeouts fixed.

* TF import fix

* Timeouts added

* Timeout fixed.

* Test corrected

* rgb and yiq conversion ops added

* Converter ops added

* Header

* Yuv converters

* API added

* Empty test for matmul

* Explanation

* skip gemm/gemv on empty inputs

Signed-off-by: raver119 <raver119@gmail.com>

* Test added

* Correct test

* one more empty pass-through for mmul

Signed-off-by: raver119 <raver119@gmail.com>

* Cleanup

* Test added

* Test fixed

* Added missing mapping

* Added missing mapping

Co-authored-by: raver119 <raver119@gmail.com>
master
Alexander Stoyakin 2019-12-30 14:06:12 +02:00 committed by raver119
parent 9b329d2601
commit 010744ef9c
23 changed files with 716 additions and 46 deletions

View File

@ -20,11 +20,9 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet;
@ -47,6 +45,8 @@ public class MnistFetcherTest extends BaseDL4JTest {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
@Rule
public Timeout timeout = Timeout.seconds(300);
@BeforeClass
public static void setup() throws Exception {

View File

@ -72,7 +72,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point;
public class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Rule
protected Timeout timeout = Timeout.seconds(300);
public Timeout timeout = Timeout.seconds(300);
@Override
public DataType getDataType(){

View File

@ -71,7 +71,7 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
public TemporaryFolder temporaryFolder = new TemporaryFolder();
@Rule
protected Timeout timeout = Timeout.seconds(300);
public Timeout timeout = Timeout.seconds(300);
@Test
public void testsBasic() throws Exception {

View File

@ -17,7 +17,9 @@
package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import java.io.File;
@ -28,6 +30,9 @@ import static org.junit.Assert.assertTrue;
*/
public class SvhnDataFetcherTest extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Test
public void testSvhnDataFetcher() throws Exception {
SvhnDataFetcher fetch = new SvhnDataFetcher();

View File

@ -40,7 +40,7 @@ import static org.junit.Assert.*;
public class MultipleEpochsIteratorTest extends BaseDL4JTest {
@Rule
protected Timeout timeout = Timeout.seconds(300);
public Timeout timeout = Timeout.seconds(300);
@Test
public void testNextAndReset() throws Exception {

View File

@ -19,7 +19,9 @@ package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -33,6 +35,9 @@ import static org.junit.Assert.*;
@Slf4j
public class TestEmnistDataSetIterator extends BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Override
public DataType getDataType(){
return DataType.FLOAT;

View File

@ -22,6 +22,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
@ -36,6 +37,9 @@ import java.util.Properties;
@Slf4j
public class BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Rule
public TestName name = new TestName();

View File

@ -22,6 +22,7 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.TestName;
import org.junit.rules.Timeout;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -37,6 +38,9 @@ import java.util.Properties;
@Slf4j
public class BaseDL4JTest {
@Rule
public Timeout timeout = Timeout.seconds(600);
@Rule
public TestName name = new TestName();

View File

@ -202,6 +202,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
if(C == nullptr)
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const auto aType = A->dataType();
const auto bType = B->dataType();
const auto cType = C->dataType();
@ -307,6 +310,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
if(Y == nullptr)
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
if (Y->isEmpty())
return Y;
const int incx = X->stridesOf()[xLenDim];
const int incy = Y->stridesOf()[yLenDim];
@ -511,6 +517,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
C = new NDArray(outOrder, cExpectedShape, B->dataType());
}
if (C->isEmpty())
return C;
const int cRank = C->rankOf();
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);

View File

@ -235,6 +235,9 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
if(C == nullptr)
C = new NDArray(outOrder, {M,N}, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const int major = Environment::getInstance()->capabilities()[AffinityManager::currentDeviceId()].first();
const auto aType = A->dataType();
@ -376,6 +379,9 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
if(Y == nullptr)
Y = new NDArray(outOrder, {M}, DataTypeUtils::pickPairwiseResultType(A->dataType(), X->dataType()), A->getContext());
if (Y->isEmpty())
return Y;
const int incx = X->strideAt(xLenDim);
const int incy = Y->strideAt(yLenDim);
@ -634,6 +640,9 @@ NDArray* MmulHelper::mmulNxN(const NDArray* A, const NDArray* B, NDArray* C, con
else
C = new NDArray(outOrder, cExpectedShape, DataTypeUtils::pickPairwiseResultType(A->dataType(), B->dataType()), A->getContext());
if (C->isEmpty())
return C;
const int cRank = C->rankOf();
const int aMaxis(aRank-2), aKaxis(aRank-1), bKaxis(bRank-2), bNaxis(bRank-1), cMaxis(cRank-2), cNaxis(cRank-1);

View File

@ -236,6 +236,9 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
throw std::invalid_argument("");
}
if (z->isEmpty())
return;
NDArray* xT(const_cast<NDArray*>(x)), *yT(const_cast<NDArray*>(y)), *zT(z);
if((transX && xRank > 1) || (transY && yRank > 1)) {

View File

@ -323,4 +323,35 @@ TEST_F(EmptyTests, test_empty_reshape_1) {
delete result0;
delete result1;
}
}
TEST_F(EmptyTests, test_empty_matmul_1) {
auto x = NDArrayFactory::create<float>('c', {0, 1});
auto y = NDArrayFactory::create<float>('c', {1, 0});
auto e = NDArrayFactory::create<float>('c', {0, 0});
nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(EmptyTests, test_empty_matmul_2) {
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
auto y = NDArrayFactory::create<float>('c', {1, 4, 0});
auto e = NDArrayFactory::create<float>('c', {1, 0, 0});
nd4j::ops::matmul op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -274,3 +274,7 @@ To build libnd4j with MKL:
Then build libnd4j as before. You may have to be careful about having multiple BLAS implementations on your path. Ideally, have only MKL on the path while building libnd4j.
Note: you may be able to get some additional performance on hyperthreaded processors by setting the system/environment variable MKL_DYNAMIC to have the value 'false'.
float16_nhcw
float16_nhwc

View File

@ -138,4 +138,48 @@ public class SDImage extends SDOps {
SDVariable out = new HsvToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to YIQ format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable rgbToYiq(String name, @NonNull SDVariable input) {
SDVariable out = new RgbToYiq(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting image from YIQ to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable yiqToRgb(String name, @NonNull SDVariable input) {
SDVariable out = new YiqToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting array from RGB to YUV format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable rgbToYuv(String name, @NonNull SDVariable input) {
SDVariable out = new RgbToYuv(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
/**
* Converting image from YUV to RGB format
* @param name name
* @param input 3D image
* @return 3D image
*/
public SDVariable yuvToRgb(String name, @NonNull SDVariable input) {
SDVariable out = new YuvToRgb(sd, input).outputVariable();
return updateVariableNameAndReference(out, name);
}
}

View File

@ -593,6 +593,11 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.AdjustContrastV2.class,
org.nd4j.linalg.api.ops.custom.HsvToRgb.class,
org.nd4j.linalg.api.ops.custom.RgbToHsv.class,
org.nd4j.linalg.api.ops.custom.RgbToYiq.class,
org.nd4j.linalg.api.ops.custom.RgbToGrayscale.class,
org.nd4j.linalg.api.ops.custom.YiqToRgb.class,
org.nd4j.linalg.api.ops.custom.RgbToYuv.class,
org.nd4j.linalg.api.ops.custom.YuvToRgb.class,
org.nd4j.linalg.api.ops.custom.BitCast.class,
org.nd4j.linalg.api.ops.custom.CompareAndBitpack.class,
org.nd4j.linalg.api.ops.custom.DivideNoNan.class,
@ -609,7 +614,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
org.nd4j.linalg.api.ops.custom.Igamma.class,
org.nd4j.linalg.api.ops.custom.Igammac.class,
org.nd4j.linalg.api.ops.custom.Digamma.class
org.nd4j.linalg.api.ops.custom.Digamma.class,
org.nd4j.linalg.api.ops.custom.Lu.class
);
static {

View File

@ -0,0 +1,69 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@NoArgsConstructor
public class Lu extends DynamicCustomOp {
private DataType indexDataType;
public Lu(INDArray input) {
addInputArgument(input);
}
public Lu(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "lu";
}
@Override
public String tensorflowName() {
return "Lu";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if (attributesForNode.containsKey("output_idx_type")){
indexDataType = TFGraphMapper.convertType(attributesForNode.get("output_idx_type").getType());
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Arrays.asList(inputDataTypes.get(0), indexDataType);
}
}

View File

@ -0,0 +1,44 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
@NoArgsConstructor
public class RgbToGrayscale extends DynamicCustomOp {
public RgbToGrayscale(INDArray image) {
addInputArgument(image);
}
public RgbToGrayscale(SameDiff sameDiff, SDVariable image) {
super(sameDiff, new SDVariable[]{image});
}
@Override
public String opName() {
return "rgb_to_grs";
}
@Override
public String tensorflowName() {
return "RgbToGrs";
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RgbToYiq extends DynamicCustomOp {
public RgbToYiq(INDArray input) {
addInputArgument(input);
}
public RgbToYiq(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "rgb_to_yiq";
}
@Override
public String tensorflowName() {
return "RgbToYiq";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class RgbToYuv extends DynamicCustomOp {
public RgbToYuv(INDArray input) {
addInputArgument(input);
}
public RgbToYuv(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "rgb_to_yuv";
}
@Override
public String tensorflowName() {
return "RgbToYuv";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,55 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class YiqToRgb extends DynamicCustomOp {
public YiqToRgb(INDArray input) {
addInputArgument(input);
}
public YiqToRgb(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "yiq_to_rgb";
}
@Override
public String tensorflowName() {
return "YiqToRgb";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -0,0 +1,56 @@
/* ******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class YuvToRgb extends DynamicCustomOp {
public YuvToRgb(INDArray input) {
addInputArgument(input);
}
public YuvToRgb(SameDiff sameDiff, SDVariable input) {
super(sameDiff, new SDVariable[]{input});
}
@Override
public String opName() {
return "yuv_to_rgb";
}
@Override
public String tensorflowName() {
return "YuvToRgb";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
int n = args().length;
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
return Collections.singletonList(inputDataTypes.get(0));
}
}

View File

@ -6081,6 +6081,42 @@ public class Nd4jTestsC extends BaseNd4jTest {
assertEquals(mE, mC);
}
/*
Analog of this TF code:
a = tf.constant([], shape=[0,1])
b = tf.constant([], shape=[1, 0])
c = tf.matmul(a, b)
*/
@Test
public void testMatmul_Empty() {
val mA = Nd4j.create(0,1);
val mB = Nd4j.create(1,0);
val mC = Nd4j.create(0,0);
val op = DynamicCustomOp.builder("matmul")
.addInputs(mA, mB)
.addOutputs(mC)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.create(0,0), mC);
}
@Test
public void testMatmul_Empty1() {
val mA = Nd4j.create(1,0, 4);
val mB = Nd4j.create(1,4, 0);
val mC = Nd4j.create(1,0, 0);
val op = DynamicCustomOp.builder("mmul")
.addInputs(mA, mB)
.addOutputs(mC)
.addIntegerArguments(0,0)
.build();
Nd4j.getExecutioner().exec(op);
assertEquals(Nd4j.create(1,0,0), mC);
}
@Test
public void testScalarSqueeze() {

View File

@ -19,6 +19,7 @@ package org.nd4j.linalg.custom;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
@ -1161,6 +1162,23 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(expectedY.shape(), y.shape());
}
@Test
public void testFusedBatchNormHalf() {
INDArray x = Nd4j.create(DataType.HALF, 1,2,3,4);
//INDArray scale = Nd4j.createFromArray(new float[]{0.7717f, 0.9281f, 0.9846f, 0.4838f});
//INDArray offset = Nd4j.createFromArray(new float[]{0.9441f, 0.5957f, 0.8669f, 0.3502f});
INDArray scale = Nd4j.create(DataType.HALF, 4);
INDArray offset = Nd4j.create(DataType.HALF, 4);
INDArray y = Nd4j.createUninitialized(DataType.HALF, x.shape());
INDArray batchMean = Nd4j.create(4);
INDArray batchVar = Nd4j.create(4);
FusedBatchNorm op = new FusedBatchNorm(x, scale, offset, 0, 1,
y, batchMean, batchVar);
Nd4j.exec(op);
}
@Test
public void testMatrixBandPart() {
INDArray x = Nd4j.linspace(DataType.DOUBLE, 1.0, 1.0, 2*3*3).reshape(2,3,3);
@ -1367,76 +1385,232 @@ public class CustomOpsTests extends BaseNd4jTest {
@Test
@Ignore
public void testRgbToHsv() {
INDArray expected = Nd4j.createFromArray(new float[]{6.75000000e+01f, 2.54545455e-01f, 8.62745098e-01f, 1.80000000e+02f,
3.27777778e-01f, 7.05882353e-01f, 1.35066079e+02f, 9.26530612e-01f,
9.60784314e-01f, 7.45341615e-01f, 6.85106383e-01f, 9.21568627e-01f,
2.78688525e+02f, 7.85407725e-01f, 9.13725490e-01f, 2.10989011e+01f,
4.76439791e-01f, 7.49019608e-01f, 2.89038462e+02f, 8.48979592e-01f,
9.60784314e-01f, 1.56416185e+02f, 6.92000000e-01f, 9.80392157e-01f,
3.52881356e+02f, 5.31531532e-01f, 4.35294118e-01f, 1.07142857e+01f,
2.90155440e-01f, 7.56862745e-01f, 3.43384615e+02f, 3.86904762e-01f,
6.58823529e-01f, 1.78321678e+02f, 7.48691099e-01f, 7.49019608e-01f,
2.30645161e+02f, 7.78242678e-01f, 9.37254902e-01f, 3.19159664e+02f,
7.62820513e-01f, 6.11764706e-01f, 2.10126582e+01f, 9.71311475e-01f,
9.56862745e-01f, 2.90896552e+02f, 5.96707819e-01f, 9.52941176e-01f,
1.74822335e+02f, 9.42583732e-01f, 8.19607843e-01f, 2.06600985e+02f,
9.90243902e-01f, 8.03921569e-01f, 1.06883721e+02f, 8.70445344e-01f,
9.68627451e-01f, 1.95272727e+02f, 6.11111111e-01f, 7.05882353e-01f}).reshape(5,4,3);
INDArray input = Nd4j.createFromArray(new float[]{213.f, 220.f, 164.f, 121.f, 180.f, 180.f, 18.f, 245.f, 75.f, 235.f, 76.f, 74.f, 168.f,
50.f, 233.f, 191.f, 132.f, 100.f, 207.f, 37.f, 245.f, 77.f, 250.f, 182.f, 111.f, 52.f,
59.f, 193.f, 147.f, 137.f, 168.f, 103.f, 121.f, 48.f, 191.f, 187.f, 53.f, 82.f, 239.f,
156.f, 37.f, 118.f, 244.f, 90.f, 7.f, 221.f, 98.f, 243.f, 12.f, 209.f, 192.f, 2.f,
115.f, 205.f, 79.f, 247.f, 32.f, 70.f, 152.f, 180.f}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f,
0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f,
0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f,
0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f,
0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f,
0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f,
0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f,
0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f,
0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f,
0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f,
0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f,
0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f
}).reshape(5,4,3);
INDArray input = Nd4j.createFromArray(new float[]{
0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f,
0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f,
0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f,
0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f,
0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f,
0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f,
0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f,
0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f,
0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f,
0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f,
0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f,
0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f
}).reshape(5,4,3);
RgbToHsv op = new RgbToHsv(input);
INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected);
}
// Exact copy of libnd4j test
@Ignore
@Test
public void testHsvToRgb() {
INDArray input = Nd4j.createFromArray(new float[]{130.f, 61.f, 239.f, 117.f, 16.f, 168.f, 181.f, 223.f, 0.f, 49.f, 195.f, 195.f, 131.f,
153.f, 78.f, 86.f, 21.f, 209.f, 101.f, 14.f, 107.f, 191.f, 98.f, 210.f}).reshape(8,3);
INDArray input = Nd4j.createFromArray(new float[]{0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f,
0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f,
0.332347751f, 0.111181192f}).reshape(4,3);
INDArray expected = Nd4j.createFromArray(new float[]{263.25842697f, 0.74476987f, 0.9372549f, 279.86842105f,
0.9047619f, 0.65882353f, 71.30044843f, 1.f, 0.8745098f, 180.f, 0.74871795f, 0.76470588f,
77.6f, 0.49019608f, 0.6f, 260.74468085f,
0.89952153f, 0.81960784f, 296.12903226f, 0.86915888f,
0.41960784f, 289.82142857f, 0.53333333f, 0.82352941f}).reshape(8,3);
INDArray expected = Nd4j.createFromArray(new float[]{0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f,
0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f,
0.111181192f, 0.074230373f}).reshape(4,3);
HsvToRgb op = new HsvToRgb(input);
INDArray[] ret = Nd4j.exec(op);
assertEquals(ret[0], expected);
}
@Ignore
@Test
public void testHsvToRgb_1() {
/* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,1,3])
tf.image.hsv_to_rgb(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f, 0.8012f, 0.7244f}).
INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f}).
reshape(1,1,3);
HsvToRgb op = new HsvToRgb(image);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.53442812f,0.144007295f,0.724374652f}).reshape(1,1,3);
System.out.println(ret[0].toStringFull());
INDArray expected = Nd4j.createFromArray(new float[]{ 0.53442812f, 0.144007325f, 0.724374652f}).reshape(1,1,3);
assertEquals(expected, ret[0]);
}
@Ignore
@Test
public void testRgbToHsv_1() {
/* Emulation of simple TF test:
image = tf.random_uniform(shape = [1,2,3])
tf.image.rgb_to_hsv(image)*/
INDArray image = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,
0.2309f,0.7271f,0.1804f}).reshape(1,2,3);
INDArray image = Nd4j.createFromArray(new float[]{0.778785586f,0.801197767f,0.724374652f,
0.230894327f, 0.727141261f, 0.180390716f }).reshape(2,3);
RgbToHsv op = new RgbToHsv(image);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f, 0.095885336f, 0.801197767f,
0.317938268f, 0.751917899f, 0.727141261f}).reshape(1,2,3);
INDArray expected = Nd4j.createFromArray(new float[]{0.215289578f,0.095885336f,0.801197767f,
0.317938268f,0.751917899f,0.727141261f}).reshape(2,3);
assertEquals(expected, ret[0]);
}
@Test
public void testLu() {
INDArray input = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7.f})
.reshape(3,3);
Lu op = new Lu(input);
INDArray[] ret = Nd4j.exec(op);
INDArray expected = Nd4j.createFromArray(new float[]{1.f, 2.f, 3.f, 0.f, 2.f, 3.f, 0.f, 0.f, 7f}).reshape(3,3);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYiq() {
INDArray image = Nd4j.createFromArray(new float[]{
0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f,
0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f ,
0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f,
0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f ,
0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f,
0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f,
0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f,
0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f,
0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f,
0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f ,
0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f,
0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f,
0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f,
-0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f ,
0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f,
0.13137188f, 0.3667803f , -0.15853189f, 0.15085728f, 0.72258149f,
0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f,
0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f ,
-0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f,
0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f,
0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f,
-0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f,
-0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f
}).reshape(5,4,3);
RgbToYiq op = new RgbToYiq(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testYiqToRgb() {
INDArray image = Nd4j.createFromArray(new float[]{
0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f,
0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f,
-0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f,
0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f,
0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f,
0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f,
0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f,
-0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f,
-0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f,
0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f,
0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f,
0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f,
1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f,
0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f,
-0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f,
0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f,
0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f,
0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f,
0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f,
1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f,
0.115916842f, 0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f,
1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f,
-0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f
}).reshape(5,4,3);
YiqToRgb op = new YiqToRgb(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToGrayscale() {
INDArray image = Nd4j.createFromArray(new float[]{
1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f,
9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f,
-5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,
2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f,
1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f,
1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,
2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f,
3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f,
5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f,
1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f,
5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,
-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f,
-4.8125e+01f
}).reshape(5,4,3);
INDArray expected = Nd4j.createFromArray(new float[]{
-47.82958221f, 34.46305847f, 21.36137581f, -21.91625023f,2.49686432f,
-43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f,
-25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f,
15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f
}).reshape(5,4,1);
RgbToGrayscale op = new RgbToGrayscale(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYuv() {
INDArray image = Nd4j.createFromArray(new float[]{
10f,50f,200f
});
INDArray expected = Nd4j.createFromArray(new float[]{
55.14f , 71.2872001f, -39.6005542f
});
RgbToYuv op = new RgbToYuv(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testYuvToRgb() {
INDArray image = Nd4j.createFromArray(new float[]{
55.14f , 71.2872001f, -39.6005542f
});
INDArray expected = Nd4j.createFromArray(new float[]{
10f, 50f, 200f
});
YuvToRgb op = new YuvToRgb(image);
INDArray[] ret = Nd4j.exec(op);
assertEquals(expected, ret[0]);
}
@Test
public void testRgbToYiqEmpty() {
INDArray image = Nd4j.create(0,4,3);
RgbToYiq op = new RgbToYiq(image);
INDArray[] ret = Nd4j.exec(op);
assertArrayEquals(image.shape(), ret[0].shape());
}
}