diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index ac30b28d8..d9c8dee62 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1328,7 +1328,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < block.width(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \ shapeList->push_back(newshape); \ } \ @@ -1365,7 +1366,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < block.width(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ Nd4jLong* newshape; \ COPY_SHAPE(inputShape->at(0), newshape); \ shapeList->push_back(CONSTANT(newshape)); \ @@ -1388,7 +1390,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < block.width(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \ shapeList->push_back(newshape); \ } \ diff --git a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp index 1813946d0..4aaae3c0d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { - OP_IMPL(toggle_bits, -1, 1, true) { + OP_IMPL(toggle_bits, -1, -1, true) { for (int i = 0; i < block.width(); i++) { auto x = INPUT_VARIABLE(i);