Shugeo pad fix3 (#132)
* Expanding allowed paddings type to 64bit ints also. * Extended to int64 paddins data types for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com>master
parent
de3c0afdce
commit
e303c06042
|
@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(mirror_pad, 2, 1, false, 0, 1) {
|
||||||
|
|
||||||
DECLARE_TYPES(mirror_pad) {
|
DECLARE_TYPES(mirror_pad) {
|
||||||
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
|
getOpDescriptor()->setAllowedInputTypes(0, {ALL_FLOATS});
|
||||||
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32}); // to conform with TF
|
getOpDescriptor()->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}); // to conform with TF
|
||||||
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
|
getOpDescriptor()->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
|
||||||
DECLARE_TYPES(pad) {
|
DECLARE_TYPES(pad) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(1, {DataType::INT32}) // INT32 with TF
|
->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF
|
||||||
// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
|
// ->setAllowedInputTypes(1, {DataType::INT32, DataType::INT64}) // INT32 with TF, but used also INT64 due long shapes
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
|
@ -597,7 +597,7 @@ TEST_F(DeclarableOpsTests12, reverse_test15) {
|
||||||
TEST_F(DeclarableOpsTests12, mirrorPad_test17) {
|
TEST_F(DeclarableOpsTests12, mirrorPad_test17) {
|
||||||
|
|
||||||
NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {2,3}, {1,2,3,4,5,6}, nd4j::DataType::DOUBLE);
|
||||||
NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT32);
|
NDArray padding('c', {2,2}, {1,1,2,2}, nd4j::DataType::INT64);
|
||||||
NDArray z('c', {4,7}, nd4j::DataType::DOUBLE);
|
NDArray z('c', {4,7}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE);
|
NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE);
|
NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, nd4j::DataType::DOUBLE);
|
||||||
|
@ -621,7 +621,7 @@ TEST_F(DeclarableOpsTests12, mirrorPad_test17) {
|
||||||
TEST_F(DeclarableOpsTests12, mirrorPad_test18) {
|
TEST_F(DeclarableOpsTests12, mirrorPad_test18) {
|
||||||
|
|
||||||
NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE);
|
NDArray x('c', {3}, {1,2,3}, nd4j::DataType::DOUBLE);
|
||||||
NDArray padding('c', {2}, {1,1}, nd4j::DataType::INT32);
|
NDArray padding('c', {1, 2}, {1,1}, nd4j::DataType::INT32);
|
||||||
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
NDArray z('c', {5}, nd4j::DataType::DOUBLE);
|
||||||
NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE);
|
NDArray exp('c', {5}, {2,1,2,3,2}, nd4j::DataType::DOUBLE);
|
||||||
|
|
||||||
|
@ -1434,11 +1434,11 @@ TEST_F(DeclarableOpsTests12, pad_tests2) {
|
||||||
TEST_F(DeclarableOpsTests12, pad_tests3) {
|
TEST_F(DeclarableOpsTests12, pad_tests3) {
|
||||||
|
|
||||||
float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
|
float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
|
||||||
int padBuff[] = {1,1,2,2};
|
Nd4jLong padBuff[] = {1,1,2,2};
|
||||||
float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f};
|
float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f};
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3});
|
auto input = NDArrayFactory::create<float>(inBuff, 'c', {2,3});
|
||||||
auto paddings = NDArrayFactory::create<int>(padBuff, 'c', {2,2});
|
auto paddings = NDArrayFactory::create<Nd4jLong>(padBuff, 'c', {2,2});
|
||||||
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
auto expected = NDArrayFactory::create<float>(expBuff, 'c', {4,7});
|
||||||
|
|
||||||
nd4j::ops::pad op;
|
nd4j::ops::pad op;
|
||||||
|
|
Loading…
Reference in New Issue