parent
53ca9a76e8
commit
e0f8d86eac
|
@ -85,7 +85,7 @@ namespace nd4j {
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
||||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ namespace nd4j {
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
||||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,7 +107,7 @@ namespace nd4j {
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
||||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -114,7 +114,7 @@ namespace nd4j {
|
||||||
COPY_SHAPE(x, shapeE);
|
COPY_SHAPE(x, shapeE);
|
||||||
COPY_SHAPE(y, shapeG);
|
COPY_SHAPE(y, shapeG);
|
||||||
|
|
||||||
auto shapeList = SHAPELIST(shapeE, shapeG);
|
auto shapeList = SHAPELIST(CONSTANT(shapeE), CONSTANT(shapeG));
|
||||||
|
|
||||||
return shapeList;
|
return shapeList;
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,8 @@ namespace nd4j {
|
||||||
COPY_SHAPE(input, epsShape);
|
COPY_SHAPE(input, epsShape);
|
||||||
COPY_SHAPE(bias, gradShape);
|
COPY_SHAPE(bias, gradShape);
|
||||||
|
|
||||||
return SHAPELIST(epsShape, gradShape);
|
return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -253,7 +253,7 @@ DECLARE_SHAPE_FN(gruCell_bp) {
|
||||||
Nd4jLong *dLdbcShapeInfo = nullptr;
|
Nd4jLong *dLdbcShapeInfo = nullptr;
|
||||||
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
|
COPY_SHAPE(bcShapeInfo, dLdbcShapeInfo);
|
||||||
|
|
||||||
return SHAPELIST(dLdxShapeInfo, dLdhiShapeInfo, dLdWShapeInfo, dLdWcShapeInfo, dLdbShapeInfo, dLdbcShapeInfo);
|
return SHAPELIST(CONSTANT(dLdxShapeInfo), CONSTANT(dLdhiShapeInfo), CONSTANT(dLdWShapeInfo), CONSTANT(dLdWcShapeInfo), CONSTANT(dLdbShapeInfo), CONSTANT(dLdbcShapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ namespace ops {
|
||||||
Nd4jLong *out;
|
Nd4jLong *out;
|
||||||
COPY_SHAPE(in, out);
|
COPY_SHAPE(in, out);
|
||||||
|
|
||||||
return SHAPELIST(out);
|
return SHAPELIST(CONSTANT(out));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -341,7 +341,7 @@ namespace nd4j {
|
||||||
if (DataTypeUtils::isR(xType)) {
|
if (DataTypeUtils::isR(xType)) {
|
||||||
COPY_SHAPE(inShape, newShape);
|
COPY_SHAPE(inShape, newShape);
|
||||||
|
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
} else if (DataTypeUtils::isZ(xType)) {
|
} else if (DataTypeUtils::isZ(xType)) {
|
||||||
auto zShapeArr = INPUT_VARIABLE(0);
|
auto zShapeArr = INPUT_VARIABLE(0);
|
||||||
auto zShapeVector = zShapeArr->asVectorT<Nd4jLong>();
|
auto zShapeVector = zShapeArr->asVectorT<Nd4jLong>();
|
||||||
|
|
|
@ -47,7 +47,7 @@ namespace nd4j {
|
||||||
Nd4jLong *newShape;
|
Nd4jLong *newShape;
|
||||||
COPY_SHAPE(inShape, newShape);
|
COPY_SHAPE(inShape, newShape);
|
||||||
|
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) {
|
Nd4jStatus LegacyScalarBoolOp::validateAndExecute(Context &block) {
|
||||||
|
|
|
@ -61,7 +61,7 @@ namespace nd4j {
|
||||||
Nd4jLong *newShape;
|
Nd4jLong *newShape;
|
||||||
COPY_SHAPE(inShape, newShape);
|
COPY_SHAPE(inShape, newShape);
|
||||||
|
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue