4 additional tests

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-07-25 13:50:36 +03:00 committed by AlexDBlack
parent 70dbe70594
commit 7c5c84bea8
3 changed files with 53 additions and 3 deletions

View File

@ -31,7 +31,6 @@ namespace nd4j {
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
// output->assign(static_cast<Nd4jLong>(input->rankOf()));
output->assign(input->rankOf());
return Status::OK();

View File

@ -52,11 +52,10 @@ namespace nd4j {
auto inputShapeInfo = inputShape->at(1);
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
// FIXME: remove memcpy
Nd4jLong* outputShapeInfo(nullptr);
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
return SHAPELIST(outputShapeInfo);
return SHAPELIST(CONSTANT(outputShapeInfo));
}
DECLARE_TYPES(reshapeas) {

View File

@ -305,6 +305,58 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
delete resultB0;
}
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
nd4j::ops::reshape op;
auto result = op.execute({&array}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
auto array = NDArrayFactory::create<float>(119.f);
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
auto z = NDArrayFactory::create<float>('c', {1, 1});
nd4j::ops::reshape op;
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_rank_1) {
auto array = NDArrayFactory::create<float>('c', {4, 64});
auto e = NDArrayFactory::create<int>('c', {}, {2});
auto z = NDArrayFactory::create<int>('c', {});
nd4j::ops::rank op;
auto result = op.execute({&array}, {&z}, {}, {}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_rank_2) {
auto array = NDArrayFactory::create<float>('c', {4, 64});
auto e = NDArrayFactory::create<int>('c', {}, {2});
nd4j::ops::rank op;
auto result = op.execute({&array}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});