Shyrma transpose (#244)
* - provide contiguous strides for ouput in transpose op Signed-off-by: Yurii <iuriish@yahoo.com> * - provide contiguous strides for output in permute op Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account empty shapes properly in transpose/permute op Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
9e3c1b02b1
commit
011c272fde
|
@ -1976,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
||||
return permutei(dimensions.data(), dimensions.size());
|
||||
return permutei(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1998,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
|
|||
for (int e = 0; e < dimensions.size(); e++)
|
||||
ivec[e] = dimensions[e];
|
||||
|
||||
return permutei(ivec.data(), ivec.size());
|
||||
return permutei(ivec.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2034,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
||||
auto data = dimensions.data();
|
||||
auto size = dimensions.size();
|
||||
return permute(data, size);
|
||||
|
||||
return permute(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2048,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
||||
return permute(dimensions.data(), dimensions.size());
|
||||
|
||||
return permute(dimensions.data(), rankOf());
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2111,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
||||
permute(dimensions.data(), dimensions.size(), target);
|
||||
permute(dimensions.data(), rankOf(), target);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
||||
permute(dimensions.data(), dimensions.size(), target);
|
||||
permute(dimensions.data(), rankOf(), target);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -50,11 +50,13 @@ namespace nd4j {
|
|||
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
|
||||
|
||||
// evaluate shapeInfo of permuted array
|
||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
|
||||
// evaluate shapeInfo of transposed array
|
||||
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||
|
||||
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
|
||||
|
||||
|
|
|
@ -313,32 +313,37 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
return outShape;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of permuted array
|
||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||
|
||||
if (!arr.nonNull())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||
if (!arr.nonNull())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||
|
||||
if (rank != arr.rankOf())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||
if (rank != arr.rankOf())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||
|
||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||
// allocate memory for new array - shapeInfo
|
||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||
|
||||
Nd4jLong *shapeInfoNew = nullptr;
|
||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||
// copy arr _shapeInfo into new array
|
||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||
// perform buffer permutation
|
||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
|
||||
// allocate memory for new array - shapeInfo
|
||||
Nd4jLong *shapeInfoNew = nullptr;
|
||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||
|
||||
ShapeDescriptor descriptor(shapeInfoNew);
|
||||
RELEASE(shapeInfoNew, workspace);
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
// copy arr _shapeInfo into new array
|
||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||
|
||||
// perform buffer permutation
|
||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
|
||||
|
||||
if(setContigStrides)
|
||||
shape::updateStrides(shapeInfoNew, arr.ordering());
|
||||
|
||||
ShapeDescriptor descriptor(shapeInfoNew);
|
||||
|
||||
RELEASE(shapeInfoNew, workspace);
|
||||
|
||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of permuted array
|
||||
|
@ -350,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// evaluate shapeInfo of transposed array
|
||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||
|
||||
int rank = arr.rankOf();
|
||||
std::vector<int> dimensions(rank);
|
||||
for (int i = 0; i < rank; ++i)
|
||||
dimensions[i] = rank - 1 - i;
|
||||
|
||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
|
||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver119 on 29/10/17.
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -29,80 +30,52 @@ namespace nd4j {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is int vector of ordered set of dimensions to be permuted
|
||||
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
||||
|
||||
bool replace = false;
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
std::vector<int> arguments({});
|
||||
if(origArgs.size() > 0){
|
||||
for (int e = 0; e < origArgs.size(); e++) {
|
||||
int ax = origArgs[e];
|
||||
if (ax < 0)
|
||||
ax += x->rankOf();
|
||||
|
||||
arguments.emplace_back(ax);
|
||||
}
|
||||
|
||||
replace = true;
|
||||
} else {
|
||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
||||
arguments.emplace_back(e);
|
||||
}
|
||||
|
||||
// 0D edge case
|
||||
if (x->rankOf() == 0) {
|
||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
if (!block.isInplace())
|
||||
output->assign(x);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if(block.isInplace()) { // in-place
|
||||
x->permutei(arguments);
|
||||
STORE_RESULT(x);
|
||||
} else {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
auto result = x->permute(arguments);
|
||||
output->assign(result);
|
||||
STORE_RESULT(output);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(permute) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(permute) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
if (shape::rank(inputShape->at(0)) == 0) {
|
||||
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
||||
} else if (inputShape->size() == 1 && !arguments.empty()) {
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
} else {
|
||||
if(arguments.size() == 0){
|
||||
//Reverse dimensions
|
||||
int rank = shape::rank(inputShape->at(0));
|
||||
for (int e = rank - 1; e >= 0; e--)
|
||||
arguments.emplace_back(e);
|
||||
}
|
||||
|
||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
||||
}
|
||||
|
||||
return shapeList;
|
||||
}
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||
z->assign(x->transpose());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
z->assign(x->permute(permutationVector));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(permute) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(permute) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
|
||||
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||
|
||||
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||
|
||||
return SHAPELIST(outputShapeInfo);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -24,254 +24,240 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is a vector with (optional) negative of order as first element:
|
||||
// ({-order, dim1, dim2, dim3, ...})
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
namespace ops {
|
||||
|
||||
if (block.width() == 1) {
|
||||
auto arguments = block.getIArguments();
|
||||
int argsSize = arguments->size();
|
||||
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return ND4J_STATUS_OK; //No op
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is a vector with (optional) negative of order as first element:
|
||||
// ({-order, dim1, dim2, dim3, ...})
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
if (block.width() == 1) {
|
||||
|
||||
auto arguments = block.getIArguments();
|
||||
int argsSize = arguments->size();
|
||||
|
||||
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = 'c'; //x->ordering();
|
||||
e = 0;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if (arguments->at(e) == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = 'c'; //x->ordering();
|
||||
e = 0;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(argsSize - e >= 1, 0, "Reshape arguments should have at least 1 dimension");
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if (arguments->at(e) == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
if (block.isInplace()) {
|
||||
if (x->reshapei(order, shapeNew)) {
|
||||
STORE_RESULT(*x);
|
||||
return ND4J_STATUS_OK;
|
||||
}
|
||||
} else {
|
||||
auto ret = OUTPUT_VARIABLE(0);
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
ret->assign(xr);
|
||||
STORE_RESULT(*ret);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} else if (block.width() == 2) {
|
||||
auto s = INPUT_VARIABLE(1);
|
||||
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
char order = 'c';
|
||||
if (block.numI() > 0)
|
||||
order = (char) -INT_ARG(0);
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
||||
|
||||
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
||||
auto dim = s->e<Nd4jLong >(e);
|
||||
if (dim == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(int e2 = 0; e2 < e; e2++){
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
else{
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
if (block.isInplace()) {
|
||||
if (x->reshapei(order, shapeNew)) {
|
||||
STORE_RESULT(*x);
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
auto ret = OUTPUT_VARIABLE(0);
|
||||
if (s->isEmpty()) {
|
||||
// just a scalar
|
||||
ret->assign(x);
|
||||
} else {
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
ret->assign(xr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
|
||||
return ND4J_STATUS_BAD_INPUT;
|
||||
}
|
||||
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
REQUIRE_TRUE(len == x->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||
|
||||
DECLARE_TYPES(reshape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(reshape) {
|
||||
auto inp = inputShape->at(0);
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
z->assign(xr);
|
||||
STORE_RESULT(*z);
|
||||
|
||||
// we can launch op using Int arguments
|
||||
if (inputShape->size() == 1) {
|
||||
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
||||
std::vector<int> *arguments = block.getIArguments();
|
||||
return Status::OK();
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = shape::order(inp);
|
||||
e = 0;
|
||||
} else if (block.width() == 2) {
|
||||
|
||||
auto s = INPUT_VARIABLE(1);
|
||||
|
||||
char order = 'c';
|
||||
if (block.numI() > 0)
|
||||
order = (char) -INT_ARG(0);
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(s->lengthOf());
|
||||
|
||||
for (int e = 0; e < (int) s->lengthOf(); e++) {
|
||||
auto dim = s->e<Nd4jLong >(e);
|
||||
if (dim == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(int e2 = 0; e2 < e; e2++){
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if ((int) arguments->at(e) == -1){
|
||||
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2 ++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew.push_back(0);
|
||||
} else {
|
||||
//Standard case
|
||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
for(int e2 = e + 1; e2 < (int) s->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(s->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= s->e<Nd4jLong>(e2);
|
||||
}
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
|
||||
} else {
|
||||
// or, with second input "as shape"
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
// special case here
|
||||
if (y->isEmpty()) {
|
||||
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
|
||||
}
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||
Nd4jLong prod = 1;
|
||||
bool hasNegs = false;
|
||||
for (auto v:shapeOf) {
|
||||
if (v < 0) {
|
||||
hasNegs = true;
|
||||
v = 0;
|
||||
}
|
||||
|
||||
prod *= v;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||
|
||||
// if there are -1s - we turn them into zeros
|
||||
if (hasNegs) {
|
||||
for (int e = 0; e < shapeOf.size(); e++)
|
||||
if (shapeOf[e] < 0)
|
||||
shapeOf[e] = 0;
|
||||
}
|
||||
|
||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||
return SHAPELIST(CONSTANT(newShape));
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
||||
|
||||
for (int e = 0; e < (int) y->lengthOf(); e++) {
|
||||
auto dim = y->e<Nd4jLong>(e);
|
||||
if (dim == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(int e2 = 0; e2 < e; e2++){
|
||||
shapeLength *= y->e<Nd4jLong>(e2);
|
||||
}
|
||||
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= y->e<Nd4jLong>(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew[e] = 0;
|
||||
} else {
|
||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
}else {
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
||||
Nd4jLong realShape = x->lengthOf() / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
else{
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
if (Environment::getInstance()->isDebugAndVerbose()) {
|
||||
nd4j_printv("Reshape: new shape", shapeNew);
|
||||
}
|
||||
|
||||
if (s->isEmpty()) {
|
||||
// just a scalar
|
||||
z->assign(x);
|
||||
} else {
|
||||
auto xr = x->reshape(order, shapeNew);
|
||||
z->assign(xr);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
}
|
||||
|
||||
return ND4J_STATUS_BAD_INPUT;
|
||||
}
|
||||
|
||||
|
||||
DECLARE_TYPES(reshape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(reshape) {
|
||||
auto inp = inputShape->at(0);
|
||||
|
||||
// we can launch op using Int arguments
|
||||
if (inputShape->size() == 1) {
|
||||
REQUIRE_TRUE(block.numI() > 0, 0, "Reshape: new shape should be provided as NDArray or int arguments, but nothing was defined");
|
||||
std::vector<int> *arguments = block.getIArguments();
|
||||
|
||||
int e = 1;
|
||||
char order = (char) -(*arguments)[0];
|
||||
if (order != 'c' && order != 'f') {
|
||||
order = shape::order(inp);
|
||||
e = 0;
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
|
||||
int e2 = e;
|
||||
for (; e < (int) arguments->size(); e++) {
|
||||
if ((int) arguments->at(e) == -1){
|
||||
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(; e2 < e; e2 ++){
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
for(e2 = e + 1; e2 < arguments->size(); e2++){
|
||||
REQUIRE_TRUE(arguments->at(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= arguments->at(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew.push_back(0);
|
||||
} else {
|
||||
//Standard case
|
||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew.push_back(realShape);
|
||||
}
|
||||
}
|
||||
else{
|
||||
shapeNew.push_back(arguments->at(e));
|
||||
}
|
||||
}
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(inp), order, shapeNew)));
|
||||
} else {
|
||||
// or, with second input "as shape"
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
// special case here
|
||||
if (y->isEmpty()) {
|
||||
REQUIRE_TRUE(x->lengthOf() == 1, 0, "Reshape: new length doesn't match existing array");
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inp)));
|
||||
}
|
||||
//Special case: empty.reshape(-1) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||
Nd4jLong prod = 1;
|
||||
bool hasNegs = false;
|
||||
for (auto v:shapeOf) {
|
||||
if (v < 0) {
|
||||
hasNegs = true;
|
||||
v = 0;
|
||||
}
|
||||
|
||||
prod *= v;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||
|
||||
// if there are -1s - we turn them into zeros
|
||||
if (hasNegs) {
|
||||
for (int e = 0; e < shapeOf.size(); e++)
|
||||
if (shapeOf[e] < 0)
|
||||
shapeOf[e] = 0;
|
||||
}
|
||||
|
||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||
return SHAPELIST(CONSTANT(newShape));
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> shapeNew(y->lengthOf());
|
||||
|
||||
for (int e = 0; e < (int) y->lengthOf(); e++) {
|
||||
auto dim = y->e<Nd4jLong>(e);
|
||||
if (dim == -1){
|
||||
Nd4jLong shapeLength = 1;
|
||||
for(int e2 = 0; e2 < e; e2++){
|
||||
shapeLength *= y->e<Nd4jLong>(e2);
|
||||
}
|
||||
for(int e2 = e + 1; e2 < (int)y->lengthOf(); e2++){
|
||||
REQUIRE_TRUE(y->e<Nd4jLong>(e2) != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
shapeLength *= y->e<Nd4jLong>(e2);
|
||||
}
|
||||
|
||||
if(shapeLength == 0){
|
||||
//Edge case for empty:
|
||||
shapeNew[e] = 0;
|
||||
} else {
|
||||
Nd4jLong realShape = shape::length(inp) / shapeLength;
|
||||
shapeNew[e] = realShape;
|
||||
}
|
||||
}else {
|
||||
shapeNew[e] = dim;
|
||||
}
|
||||
}
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -29,34 +29,26 @@ namespace nd4j {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(reshapeas, 2, 1, false, 0, 0) {
|
||||
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto y = INPUT_VARIABLE(1);
|
||||
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
|
||||
char order = y->ordering();
|
||||
|
||||
if (x->reshapei(order, shapeNew)) {
|
||||
*z = *x;
|
||||
STORE_RESULT(*z);
|
||||
if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
|
||||
|
||||
z->assign(x);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return ND4J_STATUS_BAD_INPUT;
|
||||
}
|
||||
DECLARE_SYN(reshape_as, reshapeas);
|
||||
|
||||
DECLARE_SHAPE_FN(reshapeas) {
|
||||
|
||||
auto inputShapeInfo = inputShape->at(1);
|
||||
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
|
||||
|
||||
Nd4jLong* outputShapeInfo(nullptr);
|
||||
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(outputShapeInfo));
|
||||
}
|
||||
DECLARE_SHAPE_FN(reshapeas) {
|
||||
|
||||
return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
|
||||
}
|
||||
|
||||
DECLARE_TYPES(reshapeas) {
|
||||
getOpDescriptor()
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by raver119 on 29/10/17.
|
||||
// @author raver119@gmail.com
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
|
@ -25,113 +26,52 @@
|
|||
#include <helpers/ShapeUtils.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace ops {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
if (block.width() == 1) {
|
||||
if (block.isInplace()) {
|
||||
x->transposei();
|
||||
STORE_RESULT(*x);
|
||||
} else {
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
auto t = x->transpose();
|
||||
output->assign(t);
|
||||
STORE_RESULT(*output);
|
||||
}
|
||||
} else {
|
||||
// this is tf-mode transpose, that's nd4j permute
|
||||
bool replace = false;
|
||||
std::vector<int> arguments(*block.getIArguments());
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
|
||||
|
||||
auto w = block.width();
|
||||
auto a = arguments.size();
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
if (w == 2 && a == 0) {
|
||||
auto axis = INPUT_VARIABLE(1);
|
||||
for (int e = 0; e < axis->lengthOf(); e++) {
|
||||
auto ax = axis->e<int>(e);
|
||||
if (ax < 0)
|
||||
ax += x->rankOf();
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
|
||||
arguments.emplace_back(ax);
|
||||
}
|
||||
|
||||
replace = true;
|
||||
} else if (a == 0) {
|
||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
||||
arguments.emplace_back(e);
|
||||
}
|
||||
|
||||
// 0D edge case
|
||||
if (x->rankOf() == 0) {
|
||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
if (!block.isInplace())
|
||||
output->assign(x);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if(block.isInplace()) { // in-place
|
||||
x->permutei(arguments);
|
||||
STORE_RESULT(x);
|
||||
} else {
|
||||
auto input = x->permute(arguments);
|
||||
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
output->assign(input);
|
||||
}
|
||||
}
|
||||
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||
z->assign(x->transpose());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(transpose) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
DECLARE_SHAPE_FN(transpose) {
|
||||
if (block.width() == 1) {
|
||||
auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
|
||||
return SHAPELIST(outputShapeInfo);
|
||||
} else {
|
||||
// this is basically permute mode
|
||||
auto shapeList = SHAPELIST();
|
||||
auto arguments = block.getIArguments();
|
||||
if (shape::rank(inputShape->at(0)) == 0) {
|
||||
Nd4jLong *newshape;
|
||||
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
|
||||
newshape[0] = 0;
|
||||
newshape[1] = 0;
|
||||
newshape[2] = 1;
|
||||
newshape[3] = 99;
|
||||
ArrayOptions::copyDataType(newshape, inputShape->at(0));
|
||||
shapeList->push_back(newshape);
|
||||
} else if (arguments->size() > 0 || inputShape->size() > 1) {
|
||||
auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
|
||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
|
||||
shapeList->push_back(outputShapeInfo);
|
||||
} else if (inputShape->size() == 2) {
|
||||
// dead end
|
||||
auto axis = INPUT_VARIABLE(1);
|
||||
auto axisV = axis->template asVectorT<Nd4jLong>();
|
||||
auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
|
||||
shapeList->push_back(newshape);
|
||||
} else {
|
||||
int rank = shape::rank(inputShape->at(0));
|
||||
for (int e = rank - 1; e >= 0; e--)
|
||||
arguments->emplace_back(e);
|
||||
z->assign(x->permute(permutationVector));
|
||||
|
||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
|
||||
shapeList->push_back(outputShapeInfo);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(transpose) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(transpose) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
|
||||
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||
|
||||
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||
|
||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||
|
||||
return SHAPELIST(outputShapeInfo);
|
||||
}
|
||||
|
||||
return shapeList;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1882,36 +1882,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
|
|||
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Reshape1) {
|
||||
const std::vector<Nd4jLong> xShape = {5,4,3};
|
||||
const std::vector<Nd4jLong> yShape = {3,5,4};
|
||||
|
||||
auto x = NDArrayFactory::create_<float>('f', xShape);
|
||||
auto y = NDArrayFactory::create_<float>('f', yShape);
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
|
||||
auto block = new Context(1, variableSpace, true);
|
||||
block->fillInputs({-1});
|
||||
std::vector<int>* arguments = block->getIArguments();
|
||||
arguments->push_back(-y->ordering());
|
||||
arguments->push_back(3);
|
||||
arguments->push_back(5);
|
||||
arguments->push_back(4);
|
||||
|
||||
nd4j::ops::reshape reshape;
|
||||
|
||||
reshape.execute(block);
|
||||
|
||||
ASSERT_TRUE(x->isSameShape(y));
|
||||
|
||||
delete y;
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Reshape2) {
|
||||
const std::vector<Nd4jLong> xShape = {5,4,3};
|
||||
|
@ -2022,37 +1992,8 @@ TEST_F(DeclarableOpsTests1, Reshape7){
|
|||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Transpose1) {
|
||||
|
||||
auto x = NDArrayFactory::create_<float>('c', {3,5,2});
|
||||
auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
|
||||
auto block = new Context(1, variableSpace, true); // in-place
|
||||
block->fillInputs({-1});
|
||||
nd4j::ops::transpose transpose;
|
||||
|
||||
Nd4jStatus status = transpose.execute(block);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
// ASSERT_TRUE(x.isSameShapeStrict(exp));
|
||||
|
||||
for (int e = 0; e < x->rankOf() * 2 + 2; e++) {
|
||||
ASSERT_EQ(x->getShapeInfo()[e], exp->getShapeInfo()[e]);
|
||||
}
|
||||
// ASSERT_EQ(x.getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
|
||||
ASSERT_EQ(x->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
|
||||
|
||||
delete exp;
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Transpose2) {
|
||||
auto x = NDArrayFactory::create_<float>('c', {3,5,2});
|
||||
auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
|
||||
auto exp = NDArrayFactory::create_<float>('c', {2,5,3});
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
|
@ -2066,12 +2007,10 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
||||
// ASSERT_TRUE(result->isSameShapeStrict(exp));
|
||||
for (int e = 0; e < result->rankOf() * 2 + 2; e++) {
|
||||
ASSERT_EQ(result->getShapeInfo()[e], exp->getShapeInfo()[e]);
|
||||
}
|
||||
//ASSERT_EQ(result->getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
|
||||
ASSERT_EQ(result->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
|
||||
|
||||
ASSERT_TRUE(exp->isSameShape(result));
|
||||
ASSERT_TRUE(exp->dataType() == result->dataType());
|
||||
ASSERT_TRUE(exp->ordering() == result->ordering());
|
||||
|
||||
delete exp;
|
||||
delete block;
|
||||
|
@ -2079,44 +2018,12 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
|
|||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// in-place
|
||||
TEST_F(DeclarableOpsTests1, Permute1) {
|
||||
|
||||
Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
|
||||
Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
|
||||
const std::vector<int> perm = {2, 0, 1};
|
||||
ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shapeExp, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto x = new NDArray(shapeX,true);
|
||||
auto exp = new NDArray(shapeExp,true);
|
||||
|
||||
auto variableSpace = new VariableSpace();
|
||||
variableSpace->putVariable(-1, x);
|
||||
|
||||
auto block = new Context(1, variableSpace, true); // in-place
|
||||
block->fillInputs({-1});
|
||||
std::vector<int>* arguments = block->getIArguments();
|
||||
*arguments = perm; // set dimensions to be permuted
|
||||
|
||||
nd4j::ops::permute permute;
|
||||
Nd4jStatus status = permute.execute(block);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
ASSERT_TRUE(x->isSameShapeStrict(*exp));
|
||||
|
||||
delete exp;
|
||||
delete block;
|
||||
delete variableSpace;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
// not-in-place
|
||||
TEST_F(DeclarableOpsTests1, Permute2) {
|
||||
TEST_F(DeclarableOpsTests1, Permute1) {
|
||||
|
||||
Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
|
||||
Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
|
||||
Nd4jLong shapeX[] = {3, 5,10,15, 150,15,1, 0,1,99};
|
||||
Nd4jLong shapeExp[] = {3, 15,5,10, 50,10,1, 0,1,99};
|
||||
const std::vector<int> perm = {2, 0, 1};
|
||||
|
||||
ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
|
||||
|
|
|
@ -161,23 +161,6 @@ TEST_F(EmptyTests, Test_Reshape_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_2) {
|
||||
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
||||
auto exp = NDArrayFactory::create<float>(119.0f);
|
||||
auto empty = NDArrayFactory::empty_<Nd4jLong>();
|
||||
|
||||
nd4j::ops::reshape op;
|
||||
auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
ASSERT_EQ(exp, *result->at(0));
|
||||
ASSERT_EQ(exp, vector);
|
||||
|
||||
delete empty;
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
||||
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
||||
|
|
|
@ -65,7 +65,7 @@ TEST_F(PlaygroundTests, test_avx) {
|
|||
nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel());
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
TEST_F(PlaygroundTests, test_bert_1) {
|
||||
// this test will run ONLY if this model exists
|
||||
if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
|
||||
|
@ -86,15 +86,15 @@ TEST_F(PlaygroundTests, test_bert_1) {
|
|||
graph->getVariableSpace()->putVariable(86,0, u);
|
||||
graph->getVariableSpace()->putVariable(87,0, v);
|
||||
|
||||
/*
|
||||
// validating graph now
|
||||
auto status = GraphExecutioner::execute(graph);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
|
||||
|
||||
auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
|
||||
ASSERT_EQ(z, *array);
|
||||
*/
|
||||
// validating graph now
|
||||
// auto status = GraphExecutioner::execute(graph);
|
||||
// ASSERT_EQ(Status::OK(), status);
|
||||
// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
|
||||
|
||||
// auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
|
||||
// ASSERT_EQ(z, *array);
|
||||
|
||||
|
||||
nd4j::Environment::getInstance()->setProfiling(true);
|
||||
auto profile = GraphProfilingHelper::profile(graph, 1);
|
||||
|
@ -104,28 +104,27 @@ TEST_F(PlaygroundTests, test_bert_1) {
|
|||
nd4j::Environment::getInstance()->setProfiling(false);
|
||||
delete profile;
|
||||
|
||||
/*
|
||||
std::vector<Nd4jLong> values;
|
||||
|
||||
for (int e = 0; e < 1; e++) {
|
||||
auto timeStart = std::chrono::system_clock::now();
|
||||
// std::vector<Nd4jLong> values;
|
||||
|
||||
GraphExecutioner::execute(graph);
|
||||
// for (int e = 0; e < 1; e++) {
|
||||
// auto timeStart = std::chrono::system_clock::now();
|
||||
|
||||
auto timeEnd = std::chrono::system_clock::now();
|
||||
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
||||
values.emplace_back(outerTime);
|
||||
}
|
||||
// GraphExecutioner::execute(graph);
|
||||
|
||||
std::sort(values.begin(), values.end());
|
||||
// auto timeEnd = std::chrono::system_clock::now();
|
||||
// auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
||||
// values.emplace_back(outerTime);
|
||||
// }
|
||||
|
||||
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||
*/
|
||||
// std::sort(values.begin(), values.end());
|
||||
|
||||
// nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||
|
||||
delete graph;
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
TEST_F(PlaygroundTests, test_broadcast_1) {
|
||||
int pool = 10;
|
||||
std::vector<NDArray*> aX(pool);
|
||||
|
|
Loading…
Reference in New Issue