Refactored pad and mirror_pad ops to conform with TF. (#100)

master
shugeo 2019-12-03 14:06:38 +02:00 committed by raver119
parent d8339246d9
commit 190575196c
3 changed files with 5 additions and 4 deletions

View File

@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) {
DECLARE_TYPES(mirror_pad) {
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS});
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
}

View File

@ -78,7 +78,8 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
DECLARE_TYPES(pad) {
getOpDescriptor()
->setAllowedInputTypes(0, nd4j::DataType::ANY)
->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF
// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
->setSameMode(true);
}

View File

@ -4549,7 +4549,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test13) {
TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
auto paddings = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 0, 0, 1});
auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {1LL, 0LL, 0LL, 1LL});
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5});
@ -4567,7 +4567,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
TEST_F(DeclarableOpsTests7, mirrorPad_test15) {
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
auto paddings = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 1, 0, 0});
auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {1, 1, 0, 0});
auto exp = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});