Shugeo pad fix3 (#132)

* Expanding allowed paddings type to 64bit ints also.

* Extended to int64 paddins data types for mirror_pad op.

Signed-off-by: shugeo <sgazeos@gmail.com>
master
shugeo 2019-12-19 12:14:02 +02:00 committed by raver119
parent de3c0afdce
commit e303c06042
3 changed files with 6 additions and 6 deletions

View File

@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) {
DECLARE_TYPES(mirror_pad) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
}

View File

@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
DECLARE_TYPES(pad) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF
->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF
// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
->setSameMode(true);
}

View File

@ -597,7 +597,7 @@ TEST_F(DeclarableOpsTests12, reverse_test15) {
TEST_F(DeclarableOpsTests12, mirrorPad_test17) {
NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32);
NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT64);
NDArray z('c', {4,7}, nd4j::DataType::DOUBLE);
NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE);
NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE);
@ -621,7 +621,7 @@ TEST_F(DeclarableOpsTests12, mirrorPad_test17) {
TEST_F(DeclarableOpsTests12, mirrorPad_test18) {
NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE);
NDArray padding('c', {2}, {1,1}, nd4j::DataType::INT32);
NDArray padding('c', {1, 2}, {1,1}, nd4j::DataType::INT32);
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE);
@ -1434,11 +1434,11 @@ TEST_F(DeclarableOpsTests12, pad_tests2) {
TEST_F(DeclarableOpsTests12, pad_tests3) {
float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
int padBuff[] = {1,1,2,2};
Nd4jLong padBuff[] = {1,1,2,2};
float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3});
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
auto paddings = NDArrayFactory::create<Nd4jLong>(padBuff, 'c', {2,2});
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
nd4j::ops::pad op;