diff --git a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp index fac8451a5..2de8ee5a2 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/mirrorPad.cpp @@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) { DECLARE_TYPES(mirror_pad) { getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS}); - getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF + getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS}); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp index 9d410a6c3..c6c8c8ff8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/pad.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/pad.cpp @@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) { DECLARE_TYPES(pad) { getOpDescriptor() ->setAllowedInputTypes(0, nd4j::DataType::ANY) - ->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF + ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF // ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes ->setSameMode(true); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 5ca22c95e..0d205c2db 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -597,7 +597,7 @@ TEST_F(DeclarableOpsTests12, reverse_test15) { TEST_F(DeclarableOpsTests12, mirrorPad_test17) { NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE); - NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32); + NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT64); NDArray z('c', {4,7}, nd4j::DataType::DOUBLE); NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE); NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE); @@ -621,7 +621,7 @@ TEST_F(DeclarableOpsTests12, mirrorPad_test17) { TEST_F(DeclarableOpsTests12, mirrorPad_test18) { NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE); - NDArray padding('c', {2}, {1,1}, nd4j::DataType::INT32); + NDArray padding('c', {1, 2}, {1,1}, nd4j::DataType::INT32); NDArray z('c', {5}, nd4j::DataType::DOUBLE); NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE); @@ -1434,11 +1434,11 @@ TEST_F(DeclarableOpsTests12, pad_tests2) { TEST_F(DeclarableOpsTests12, pad_tests3) { float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - int padBuff[] = {1,1,2,2}; + Nd4jLong padBuff[] = {1,1,2,2}; float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); - auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); nd4j::ops::pad op;