reshape tweak (#275)

* - expand dims tweak
- reshape memcpy

Signed-off-by: raver119 <raver119@gmail.com>

* validation fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-26 14:05:32 +03:00 committed by GitHub
parent b686368b82
commit 5c806d2fb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 6 deletions

View File

@ -1229,7 +1229,7 @@
/// graph definitions /// graph definitions
#define REQUIRE_OK(A) if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION; #define REQUIRE_OK(A) if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION;
#define REQUIRE_TRUE(...) if (nd4j::ops::conditionHelper(__FILE__, __LINE__, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed"); #define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (nd4j::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");};
#define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat<NAME<float>>; \ #define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat<NAME<float>>; \
template struct ND4J_EXPORT __registratorHalf<NAME<float16>>; \ template struct ND4J_EXPORT __registratorHalf<NAME<float16>>; \

View File

@ -41,9 +41,10 @@ namespace nd4j {
REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis); REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis);
std::vector<Nd4jLong> shape; std::vector<Nd4jLong> shape(input->rankOf());
for(int e = 0; e < input->rankOf(); e++) for(int e = 0; e < input->rankOf(); e++)
shape.emplace_back(input->sizeAt(e)); shape[input->sizeAt(e)];
shape.insert(shape.begin() + axis, 1); shape.insert(shape.begin() + axis, 1);

View File

@ -122,13 +122,18 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
nd4j_printv("Reshape: new shape", shapeNew); nd4j_printv("Reshape: new shape", shapeNew);
} }
if (s->isEmpty()) { if (s->isScalar()) {
// just a scalar // just a scalar
z->assign(x); z->assign(x);
} else {
// in some cases we might go away with simple memcpy call instead of assign call
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
} else { } else {
auto xr = x->reshape(order, shapeNew); auto xr = x->reshape(order, shapeNew);
z->assign(xr); z->assign(xr);
} }
}
return Status::OK(); return Status::OK();