bunch of shape functions fixed

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-14 18:28:25 +03:00
parent 53ca9a76e8
commit e0f8d86eac
10 changed files with 11 additions and 10 deletions

View File

@ -85,7 +85,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG); auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList; return shapeList;
} }

View File

@ -86,7 +86,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG); auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList; return shapeList;
} }

View File

@ -107,7 +107,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG); auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList; return shapeList;
} }

View File

@ -114,7 +114,7 @@ namespace nd4j {
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);
auto shapeList = SHAPELIST(shapeE, shapeG); auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
return shapeList; return shapeList;
} }

View File

@ -112,7 +112,8 @@ namespace nd4j {
COPY_SHAPE(input, epsShape); COPY_SHAPE(input, epsShape);
COPY_SHAPE(bias, gradShape); COPY_SHAPE(bias, gradShape);
return SHAPELIST(epsShape, gradShape); return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape));
} }
} }
} }

View File

@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) {
Nd4jLong *dLdbcShapeInfo = nullptr; Nd4jLong *dLdbcShapeInfo = nullptr;
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo); COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo); return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo));
} }

View File

@ -101,7 +101,7 @@ namespace ops {
Nd4jLong *out; Nd4jLong *out;
COPY_SHAPE(in, out); COPY_SHAPE(in, out);
return SHAPELIST(out); return SHAPELIST(CONSTANT(out));
} }
} }

View File

@ -341,7 +341,7 @@ namespace nd4j {
if (DataTypeUtils::isR(xType)) { if (DataTypeUtils::isR(xType)) {
COPY_SHAPE(inShape, newShape); COPY_SHAPE(inShape, newShape);
return SHAPELIST(newShape); return SHAPELIST(CONSTANT(newShape));
} else if (DataTypeUtils::isZ(xType)) { } else if (DataTypeUtils::isZ(xType)) {
auto zShapeArr = INPUT_VARIABLE(0); auto zShapeArr = INPUT_VARIABLE(0);
auto zShapeVector = zShapeArr->asVectorT<Nd4jLong>(); auto zShapeVector = zShapeArr->asVectorT<Nd4jLong>();

View File

@ -47,7 +47,7 @@ namespace nd4j {
Nd4jLong *newShape; Nd4jLong *newShape;
COPY_SHAPE(inShape, newShape); COPY_SHAPE(inShape, newShape);
return SHAPELIST(newShape); return SHAPELIST(CONSTANT(newShape));
} }
Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) { Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) {

View File

@ -61,7 +61,7 @@ namespace nd4j {
Nd4jLong *newShape; Nd4jLong *newShape;
COPY_SHAPE(inShape, newShape); COPY_SHAPE(inShape, newShape);
return SHAPELIST(newShape); return SHAPELIST(CONSTANT(newShape));
} }
} }
} }