parent
70dbe70594
commit
7c5c84bea8
|
@ -31,7 +31,6 @@ namespace nd4j {
|
|||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar");
|
||||
|
||||
// output->assign(static_cast<Nd4jLong>(input->rankOf()));
|
||||
output->assign(input->rankOf());
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
@ -52,11 +52,10 @@ namespace nd4j {
|
|||
auto inputShapeInfo = inputShape->at(1);
|
||||
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
|
||||
|
||||
// FIXME: remove memcpy
|
||||
Nd4jLong* outputShapeInfo(nullptr);
|
||||
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
|
||||
|
||||
return SHAPELIST(outputShapeInfo);
|
||||
return SHAPELIST(CONSTANT(outputShapeInfo));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(reshapeas) {
|
||||
|
|
|
@ -305,6 +305,58 @@ TEST_F(DeclarableOpsTests15, test_hashCode_2) {
|
|||
delete resultB0;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_1) {
|
||||
auto array = NDArrayFactory::create<float>(119.f);
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
||||
|
||||
nd4j::ops::reshape op;
|
||||
auto result = op.execute({&array}, {}, {1, 1});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_reshape_to_scalar_2) {
|
||||
auto array = NDArrayFactory::create<float>(119.f);
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 1}, {119.f});
|
||||
auto z = NDArrayFactory::create<float>('c', {1, 1});
|
||||
|
||||
nd4j::ops::reshape op;
|
||||
auto result = op.execute({&array}, {&z}, {}, {1, 1}, {});
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_rank_1) {
|
||||
auto array = NDArrayFactory::create<float>('c', {4, 64});
|
||||
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
||||
auto z = NDArrayFactory::create<int>('c', {});
|
||||
|
||||
nd4j::ops::rank op;
|
||||
auto result = op.execute({&array}, {&z}, {}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_rank_2) {
|
||||
auto array = NDArrayFactory::create<float>('c', {4, 64});
|
||||
auto e = NDArrayFactory::create<int>('c', {}, {2});
|
||||
|
||||
nd4j::ops::rank op;
|
||||
auto result = op.execute({&array}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
|
||||
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
|
||||
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});
|
||||
|
|
Loading…
Reference in New Issue