diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp index 7a15967d5..9ef0ed12b 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp @@ -31,7 +31,6 @@ namespace nd4j { REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); -// output->assign(static_cast(input->rankOf())); output->assign(input->rankOf()); return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp index a5368ebed..92dc2a146 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape_as.cpp @@ -52,11 +52,10 @@ namespace nd4j { auto inputShapeInfo = inputShape->at(1); int shapeInfoLength = inputShapeInfo[0]*2 + 4; - // FIXME: remove memcpy Nd4jLong* outputShapeInfo(nullptr); COPY_SHAPE(inputShapeInfo, outputShapeInfo); - return SHAPELIST(outputShapeInfo); + return SHAPELIST(CONSTANT(outputShapeInfo)); } DECLARE_TYPES(reshapeas) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index cee0bf415..6983cb39c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -305,6 +305,58 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) { delete resultB0; } +TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + + nd4j::ops::reshape op; + auto result = op.execute({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); + + nd4j::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_rank_1) { + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + auto z = NDArrayFactory::create('c', {}); + + nd4j::ops::rank op; + auto result = op.execute({&array}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_rank_2) { + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + + nd4j::ops::rank op; + auto result = op.execute({&array}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});