From e0f8d86eac8900f01595dc6b867fc29c5f6fb8aa Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 14 Aug 2019 18:28:25 +0300 Subject: [PATCH] bunch of shape functions fixed Signed-off-by: raver119 --- libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp | 2 +- .../ops/declarable/generic/broadcastable/reverse_mod.cpp | 2 +- .../ops/declarable/generic/broadcastable/reverse_subtract.cpp | 2 +- .../include/ops/declarable/generic/broadcastable/subtract.cpp | 2 +- libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp | 3 ++- libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp | 2 +- libnd4j/include/ops/declarable/generic/transforms/reverse.cpp | 2 +- libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp | 2 +- libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp | 2 +- libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp | 2 +- 10 files changed, 11 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp index 781dea86a..5ae075c99 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/mod.cpp @@ -85,7 +85,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp index d10c32435..9dea93699 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_mod.cpp @@ -86,7 +86,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp index 887225f6a..af282fe7c 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/reverse_subtract.cpp @@ -107,7 +107,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp index f27b4fc61..76f2d6830 100644 --- a/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp +++ b/libnd4j/include/ops/declarable/generic/broadcastable/subtract.cpp @@ -114,7 +114,7 @@ namespace nd4j { COPY_SHAPE(x, shapeE); COPY_SHAPE(y, shapeG); - auto shapeList = SHAPELIST(shapeE, shapeG); + auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG)); return shapeList; } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp index 6806be664..3309c6104 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/bias_add.cpp @@ -112,7 +112,8 @@ namespace nd4j { COPY_SHAPE(input, epsShape); COPY_SHAPE(bias, gradShape); - return SHAPELIST(epsShape, gradShape); + return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape)); + } } } diff --git a/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp index 5ddd1654e..ddd18dc84 100644 --- a/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/recurrent/gruCell.cpp @@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) { Nd4jLong *dLdbcShapeInfo = nullptr; COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); - return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo); + return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo)); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 1d2b25678..8047da41a 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -101,7 +101,7 @@ namespace ops { Nd4jLong *out; COPY_SHAPE(in, out); - return SHAPELIST(out); + return SHAPELIST(CONSTANT(out)); } } diff --git a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp index 1ce00f44a..731c5a5f9 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyRandomOp.cpp @@ -341,7 +341,7 @@ namespace nd4j { if (DataTypeUtils::isR(xType)) { COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } else if (DataTypeUtils::isZ(xType)) { auto zShapeArr = INPUT_VARIABLE(0); auto zShapeVector = zShapeArr->asVectorT(); diff --git a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp index d20bf3d04..3e35e2c11 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyScalarBoolOp.cpp @@ -47,7 +47,7 @@ namespace nd4j { Nd4jLong *newShape; COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { diff --git a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp index d870f15d0..de8248d25 100644 --- a/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp +++ b/libnd4j/include/ops/declarable/impl/LegacyTransformOp.cpp @@ -61,7 +61,7 @@ namespace nd4j { Nd4jLong *newShape; COPY_SHAPE(inShape, newShape); - return SHAPELIST(newShape); + return SHAPELIST(CONSTANT(newShape)); } } }