Refactored pad and mirror_pad ops to conform with TF. (#100)
parent
d8339246d9
commit
190575196c
|
@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
DECLARE_TYPES(mirror_pad) {
|
DECLARE_TYPES(mirror_pad) {
|
||||||
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
|
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
|
||||||
getOpDescriptor()->setAllowedInputTypes(1, {ALL_INTS});
|
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF
|
||||||
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
|
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,8 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
|
||||||
DECLARE_TYPES(pad) {
|
DECLARE_TYPES(pad) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
|
->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF
|
||||||
|
// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -4549,7 +4549,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test13) {
|
||||||
TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
|
TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
|
||||||
auto paddings = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 0, 0, 1});
|
auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {1LL, 0LL, 0LL, 1LL});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5});
|
auto exp = NDArrayFactory::create<double>('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 5});
|
||||||
|
|
||||||
|
@ -4567,7 +4567,7 @@ TEST_F(DeclarableOpsTests7, mirrorPad_test14) {
|
||||||
TEST_F(DeclarableOpsTests7, mirrorPad_test15) {
|
TEST_F(DeclarableOpsTests7, mirrorPad_test15) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
|
auto input = NDArrayFactory::create<double>('c', {2, 3}, {1., 2., 3., 4., 5., 6.});
|
||||||
auto paddings = NDArrayFactory::create<Nd4jLong>('c', {2, 2}, {1, 1, 0, 0});
|
auto paddings = NDArrayFactory::create<int>('c', {2, 2}, {1, 1, 0, 0});
|
||||||
|
|
||||||
auto exp = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
|
auto exp = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue