OP/CONFIGURABLE_OP shapefn fix (#125)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-16 08:53:30 +03:00 committed by GitHub
parent 2f3d7330ce
commit 5c908886b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 4 deletions

View File

@ -1328,7 +1328,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < block.width(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
shapeList->push_back(newshape); \
} \
@ -1365,7 +1366,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < block.width(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
Nd4jLong* newshape; \
COPY_SHAPE(inputShape->at(0), newshape); \
shapeList->push_back(CONSTANT(newshape)); \
@ -1388,7 +1390,8 @@
REGISTER_C(NAME) \
nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \
auto shapeList = SHAPELIST(); \
for (int e = 0; e < block.width(); e++) { \
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \
for (int e = 0; e < opLimit; e++) { \
auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \
shapeList->push_back(newshape); \
} \

View File

@ -27,7 +27,7 @@
namespace nd4j {
namespace ops {
OP_IMPL(toggle_bits, -1, 1, true) {
OP_IMPL(toggle_bits, -1, -1, true) {
for (int i = 0; i < block.width(); i++) {
auto x = INPUT_VARIABLE(i);