Added tests and fixed op name.

master
shugeo 2019-10-02 15:04:28 +03:00
parent 863ff76878
commit a27e61553a
2 changed files with 28 additions and 2 deletions

View File

@ -104,8 +104,8 @@ namespace nd4j {
*
* all as above op
* */
#if NOT_EXCLUDED(OP_bincast)
DECLARE_CUSTOM_OP(bincast, 1, 1, false, 0, 1);
#if NOT_EXCLUDED(OP_bitcast)
DECLARE_CUSTOM_OP(bitcast, 1, 1, false, 0, 1);
#endif
}
}

View File

@ -228,6 +228,32 @@ TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) {
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BinCast_1) {
auto x = NDArrayFactory::create<float>('c', {2, 2, 2});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {2., 512., 8192., 131072.032 });
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::DOUBLE}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BinCast_2) {
auto x = NDArrayFactory::create<float>('c', {2, 4});
auto e = NDArrayFactory::create<float16>('c', {2, 4, 2}, {0, 1.875, 0, 2., 0, 2.125, 0, 2.25,
0, 2.312, 0, 2.375, 0, 2.438, 0., 2.5});
x.linspace(1.);
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::HALF}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
out->printIndexedBuffer("Casted result");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});