diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index 8487f0264..5fef0c892 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1229,7 +1229,7 @@ /// graph definitions #define REQUIRE_OK(A) if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; -#define REQUIRE_TRUE(...) if (nd4j::ops::conditionHelper(__FILE__, __LINE__, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed"); +#define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (nd4j::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");}; #define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat>; \ template struct ND4J_EXPORT __registratorHalf>; \ diff --git a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp index efa723c20..50aa0cb9c 100644 --- a/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/expand_dims.cpp @@ -41,9 +41,10 @@ namespace nd4j { REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis); - std::vector shape; + std::vector shape(input->rankOf()); + for(int e = 0; e < input->rankOf(); e++) - shape.emplace_back(input->sizeAt(e)); + shape[input->sizeAt(e)]; shape.insert(shape.begin() + axis, 1); diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 4a06455eb..dba15bf22 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -122,12 +122,17 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { nd4j_printv("Reshape: new shape", shapeNew); } - if (s->isEmpty()) { + if (s->isScalar()) { // just a scalar z->assign(x); } else { - auto xr = x->reshape(order, shapeNew); - z->assign(xr); + // in some cases we might go away with simple memcpy call instead of assign call + if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) { + z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset()); + } else { + auto xr = x->reshape(order, shapeNew); + z->assign(xr); + } } return Status::OK();