diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index 7c96ae4c7..b82ab4ad6 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -104,8 +104,8 @@ namespace nd4j { * * all as above op * */ - #if NOT_EXCLUDED(OP_bincast) - DECLARE_CUSTOM_OP(bincast, 1, 1, false, 0, 1); + #if NOT_EXCLUDED(OP_bitcast) + DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1); #endif } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 0bd05cec3..685953f2d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { ASSERT_TRUE(e.equalsTo(out)); delete result; } +TEST_F(DeclarableOpsTests15, Test_BinCast_1) { + auto x = NDArrayFactory::create('c', {2, 2, 2}); + auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); + x.linspace(1.); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); +// out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} + +TEST_F(DeclarableOpsTests15, Test_BinCast_2) { + auto x = NDArrayFactory::create('c', {2, 4}); + auto e = NDArrayFactory::create('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25, + 0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5}); + x.linspace(1.); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto out = result->at(0); + out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); + delete result; +} TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { auto in = NDArrayFactory::create('c', {4, 8, 64, 64});