reshape tweak (#275)
* - expand dims tweak - reshape memcpy Signed-off-by: raver119 <raver119@gmail.com> * validation fix Signed-off-by: raver119 <raver119@gmail.com>master
parent
b686368b82
commit
5c806d2fb5
|
@ -1229,7 +1229,7 @@
|
||||||
|
|
||||||
/// graph definitions
|
/// graph definitions
|
||||||
#define REQUIRE_OK(A) if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION;
|
#define REQUIRE_OK(A) if (nd4j::ops::resultHelper( (A), #A, __FILE__, __LINE__ ) != 0) return ND4J_STATUS_VALIDATION;
|
||||||
#define REQUIRE_TRUE(...) if (nd4j::ops::conditionHelper(__FILE__, __LINE__, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");
|
#define REQUIRE_TRUE(COND, ...) if (!(COND)) { if (nd4j::ops::conditionHelper(__FILE__, __LINE__, COND, __VA_ARGS__) != 0) throw std::invalid_argument("Op validation failed");};
|
||||||
|
|
||||||
#define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat<NAME<float>>; \
|
#define DECLARE_ENTRY(NAME, ...) template struct ND4J_EXPORT __registratorFloat<NAME<float>>; \
|
||||||
template struct ND4J_EXPORT __registratorHalf<NAME<float16>>; \
|
template struct ND4J_EXPORT __registratorHalf<NAME<float16>>; \
|
||||||
|
|
|
@ -41,9 +41,10 @@ namespace nd4j {
|
||||||
|
|
||||||
REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis);
|
REQUIRE_TRUE(axis >= 0 && axis <= input->rankOf()+1, 0, "ExpandDims: axis should be in range of 0...%i in this case, but got %i instead", input->rankOf() + 1, axis);
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape;
|
std::vector<Nd4jLong> shape(input->rankOf());
|
||||||
|
|
||||||
for(int e = 0; e < input->rankOf(); e++)
|
for(int e = 0; e < input->rankOf(); e++)
|
||||||
shape.emplace_back(input->sizeAt(e));
|
shape[input->sizeAt(e)];
|
||||||
|
|
||||||
shape.insert(shape.begin() + axis, 1);
|
shape.insert(shape.begin() + axis, 1);
|
||||||
|
|
||||||
|
|
|
@ -122,12 +122,17 @@ CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
nd4j_printv("Reshape: new shape", shapeNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (s->isEmpty()) {
|
if (s->isScalar()) {
|
||||||
// just a scalar
|
// just a scalar
|
||||||
z->assign(x);
|
z->assign(x);
|
||||||
} else {
|
} else {
|
||||||
auto xr = x->reshape(order, shapeNew);
|
// in some cases we might go away with simple memcpy call instead of assign call
|
||||||
z->assign(xr);
|
if (x->ordering() == 'c' && z->ordering() == x->ordering() && shape::reshapeC(x->shapeInfo(), z->shapeInfo())) {
|
||||||
|
z->dataBuffer()->copyBufferFrom(*x->dataBuffer().get(), z->lengthOf() * DataTypeUtils::sizeOfElement(z->dataType()), 0, x->bufferOffset());
|
||||||
|
} else {
|
||||||
|
auto xr = x->reshape(order, shapeNew);
|
||||||
|
z->assign(xr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
Loading…
Reference in New Issue