Shyrma transpose (#244)
* - provide contiguous strides for output in transpose op Signed-off-by: Yurii <iuriish@yahoo.com> * - provide contiguous strides for output in permute op Signed-off-by: Yurii <iuriish@yahoo.com> * - take into account empty shapes properly in transpose/permute op Signed-off-by: Yurii <iuriish@yahoo.com> master
parent
9e3c1b02b1
commit
011c272fde
|
@ -1976,7 +1976,7 @@ bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
bool NDArray::permutei(const std::vector<int>& dimensions) {
|
||||||
return permutei(dimensions.data(), dimensions.size());
|
return permutei(dimensions.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -1998,7 +1998,7 @@ bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
|
||||||
for (int e = 0; e < dimensions.size(); e++)
|
for (int e = 0; e < dimensions.size(); e++)
|
||||||
ivec[e] = dimensions[e];
|
ivec[e] = dimensions[e];
|
||||||
|
|
||||||
return permutei(ivec.data(), ivec.size());
|
return permutei(ivec.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2034,9 +2034,8 @@ NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
NDArray NDArray::permute(const std::vector<int>& dimensions) const &{
|
||||||
auto data = dimensions.data();
|
|
||||||
auto size = dimensions.size();
|
return permute(dimensions.data(), rankOf());
|
||||||
return permute(data, size);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2048,7 +2047,8 @@ NDArray NDArray::permute(const std::vector<int>& dimensions) && {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
NDArray NDArray::permute(const std::vector<Nd4jLong>& dimensions) const & {
|
||||||
return permute(dimensions.data(), dimensions.size());
|
|
||||||
|
return permute(dimensions.data(), rankOf());
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -2111,12 +2111,12 @@ void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& targe
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
void NDArray::permute(const std::vector<int>& dimensions, NDArray& target) const {
|
||||||
permute(dimensions.data(), dimensions.size(), target);
|
permute(dimensions.data(), rankOf(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
void NDArray::permute(const std::vector<Nd4jLong>& dimensions, NDArray& target) const {
|
||||||
permute(dimensions.data(), dimensions.size(), target);
|
permute(dimensions.data(), rankOf(), target);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -50,11 +50,13 @@ namespace nd4j {
|
||||||
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
|
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
|
||||||
|
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||||
|
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||||
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
static Nd4jLong* evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||||
|
|
||||||
// evaluate shapeInfo of transposed array
|
// evaluate shapeInfo of transposed array
|
||||||
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace);
|
// if setContigStrides = true, then set contiguous strides in output shapeInfo in accordance with arr order
|
||||||
|
static Nd4jLong* evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides = false);
|
||||||
|
|
||||||
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
|
static bool copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset);
|
||||||
|
|
||||||
|
|
|
@ -313,10 +313,9 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
return outShape;
|
return outShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||||
|
|
||||||
if (!arr.nonNull())
|
if (!arr.nonNull())
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||||
|
@ -325,20 +324,26 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||||
|
|
||||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||||
// allocate memory for new array - shapeInfo
|
|
||||||
|
|
||||||
|
// allocate memory for new array - shapeInfo
|
||||||
Nd4jLong *shapeInfoNew = nullptr;
|
Nd4jLong *shapeInfoNew = nullptr;
|
||||||
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);
|
||||||
|
|
||||||
// copy arr _shapeInfo into new array
|
// copy arr _shapeInfo into new array
|
||||||
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
memcpy(shapeInfoNew, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
|
||||||
|
|
||||||
// perform buffer permutation
|
// perform buffer permutation
|
||||||
shape::doPermuteShapeInfo(shapeInfoNew, dimensions);
|
shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());
|
||||||
|
|
||||||
|
if(setContigStrides)
|
||||||
|
shape::updateStrides(shapeInfoNew, arr.ordering());
|
||||||
|
|
||||||
ShapeDescriptor descriptor(shapeInfoNew);
|
ShapeDescriptor descriptor(shapeInfoNew);
|
||||||
RELEASE(shapeInfoNew, workspace);
|
|
||||||
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
RELEASE(shapeInfoNew, workspace);
|
||||||
|
|
||||||
|
return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
|
@ -350,14 +355,14 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shapeInfo of transposed array
|
// evaluate shapeInfo of transposed array
|
||||||
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace, const bool setContigStrides) {
|
||||||
|
|
||||||
int rank = arr.rankOf();
|
int rank = arr.rankOf();
|
||||||
std::vector<int> dimensions(rank);
|
std::vector<int> dimensions(rank);
|
||||||
for (int i = 0; i < rank; ++i)
|
for (int i = 0; i < rank; ++i)
|
||||||
dimensions[i] = rank - 1 - i;
|
dimensions[i] = rank - 1 - i;
|
||||||
|
|
||||||
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace);
|
return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 29/10/17.
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -29,80 +30,52 @@ namespace nd4j {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// here iArgs is int vector of ordered set of dimensions to be permuted
|
// here iArgs is int vector of ordered set of dimensions to be permuted
|
||||||
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
CUSTOM_OP_IMPL(permute, 1, 1, true, 0, -2) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
bool replace = false;
|
if (x->isEmpty()) {
|
||||||
|
REQUIRE_TRUE(z->isEmpty(), 0, "PERMUTE OP: when input is empty, output must also be empty");
|
||||||
auto origArgs = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
return Status::OK(); //No op
|
||||||
std::vector<int> arguments({});
|
|
||||||
if(origArgs.size() > 0){
|
|
||||||
for (int e = 0; e < origArgs.size(); e++) {
|
|
||||||
int ax = origArgs[e];
|
|
||||||
if (ax < 0)
|
|
||||||
ax += x->rankOf();
|
|
||||||
|
|
||||||
arguments.emplace_back(ax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
replace = true;
|
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||||
} else {
|
z->assign(x->transpose());
|
||||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0D edge case
|
|
||||||
if (x->rankOf() == 0) {
|
|
||||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
if (!block.isInplace())
|
|
||||||
output->assign(x);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if(block.isInplace()) { // in-place
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
x->permutei(arguments);
|
|
||||||
STORE_RESULT(x);
|
z->assign(x->permute(permutationVector));
|
||||||
} else {
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
auto result = x->permute(arguments);
|
|
||||||
output->assign(result);
|
|
||||||
STORE_RESULT(output);
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(permute) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
DECLARE_TYPES(permute) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(permute) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
auto shapeList = SHAPELIST();
|
DECLARE_SHAPE_FN(permute) {
|
||||||
auto arguments = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
|
||||||
|
|
||||||
if (shape::rank(inputShape->at(0)) == 0) {
|
auto x = INPUT_VARIABLE(0);
|
||||||
shapeList->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inputShape->at(0))));
|
|
||||||
} else if (inputShape->size() == 1 && !arguments.empty()) {
|
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
|
||||||
} else {
|
|
||||||
if(arguments.size() == 0){
|
|
||||||
//Reverse dimensions
|
|
||||||
int rank = shape::rank(inputShape->at(0));
|
|
||||||
for (int e = rank - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
shapeList->push_back(ShapeUtils::evalPermShapeInfo(arguments.data(), arguments.size(), *INPUT_VARIABLE(0), block.workspace()));
|
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||||
}
|
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||||
|
|
||||||
return shapeList;
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
}
|
|
||||||
}
|
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||||
|
|
||||||
|
return SHAPELIST(outputShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -24,23 +24,29 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
|
||||||
// here iArgs is a vector with (optional) negative of order as first element:
|
|
||||||
// ({-order, dim1, dim2, dim3, ...})
|
|
||||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
if (block.width() == 1) {
|
//////////////////////////////////////////////////////////////////////////
|
||||||
auto arguments = block.getIArguments();
|
// here iArgs is a vector with (optional) negative of order as first element:
|
||||||
int argsSize = arguments->size();
|
// ({-order, dim1, dim2, dim3, ...})
|
||||||
|
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||||
if (x->isEmpty()) {
|
if (x->isEmpty()) {
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||||
return ND4J_STATUS_OK; //No op
|
return Status::OK(); //No op
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (block.width() == 1) {
|
||||||
|
|
||||||
|
auto arguments = block.getIArguments();
|
||||||
|
int argsSize = arguments->size();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
int e = 1;
|
int e = 1;
|
||||||
char order = (char) -(*arguments)[0];
|
char order = (char) -(*arguments)[0];
|
||||||
if (order != 'c' && order != 'f') {
|
if (order != 'c' && order != 'f') {
|
||||||
|
@ -77,28 +83,15 @@ namespace nd4j {
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
nd4j_printv("Reshape: new shape", shapeNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block.isInplace()) {
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
return ND4J_STATUS_OK;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto ret = OUTPUT_VARIABLE(0);
|
|
||||||
auto xr = x->reshape(order, shapeNew);
|
auto xr = x->reshape(order, shapeNew);
|
||||||
ret->assign(xr);
|
z->assign(xr);
|
||||||
STORE_RESULT(*ret);
|
STORE_RESULT(*z);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
|
||||||
} else if (block.width() == 2) {
|
|
||||||
auto s = INPUT_VARIABLE(1);
|
|
||||||
|
|
||||||
//Special case: empty.reshape(-1) -> return empty
|
} else if (block.width() == 2) {
|
||||||
if (x->isEmpty()) {
|
|
||||||
//REQUIRE_TRUE(s->lengthOf() == 1 && s->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
auto s = INPUT_VARIABLE(1);
|
||||||
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
|
||||||
return Status::OK(); //No op
|
|
||||||
}
|
|
||||||
|
|
||||||
char order = 'c';
|
char order = 'c';
|
||||||
if (block.numI() > 0)
|
if (block.numI() > 0)
|
||||||
|
@ -129,37 +122,30 @@ namespace nd4j {
|
||||||
nd4j_printv("Reshape: new shape", shapeNew);
|
nd4j_printv("Reshape: new shape", shapeNew);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (block.isInplace()) {
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
auto ret = OUTPUT_VARIABLE(0);
|
|
||||||
if (s->isEmpty()) {
|
if (s->isEmpty()) {
|
||||||
// just a scalar
|
// just a scalar
|
||||||
ret->assign(x);
|
z->assign(x);
|
||||||
} else {
|
} else {
|
||||||
auto xr = x->reshape(order, shapeNew);
|
auto xr = x->reshape(order, shapeNew);
|
||||||
ret->assign(xr);
|
z->assign(xr);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ND4J_STATUS_BAD_INPUT;
|
return ND4J_STATUS_BAD_INPUT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DECLARE_TYPES(reshape) {
|
DECLARE_TYPES(reshape) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, nd4j::DataType::ANY)
|
||||||
->setAllowedInputTypes(1, {ALL_INTS})
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshape) {
|
DECLARE_SHAPE_FN(reshape) {
|
||||||
auto inp = inputShape->at(0);
|
auto inp = inputShape->at(0);
|
||||||
|
|
||||||
// we can launch op using Int arguments
|
// we can launch op using Int arguments
|
||||||
|
@ -270,8 +256,8 @@ namespace nd4j {
|
||||||
|
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inp), 'c', shapeNew));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -34,12 +34,10 @@ namespace nd4j {
|
||||||
auto y = INPUT_VARIABLE(1);
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
std::vector<Nd4jLong> shapeNew(y->shapeOf(), y->shapeOf() + y->rankOf());
|
|
||||||
char order = y->ordering();
|
|
||||||
|
|
||||||
if (x->reshapei(order, shapeNew)) {
|
if (x->reshapei(y->ordering(), y->getShapeAsVector())) {
|
||||||
*z = *x;
|
|
||||||
STORE_RESULT(*z);
|
z->assign(x);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,14 +47,8 @@ namespace nd4j {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(reshapeas) {
|
DECLARE_SHAPE_FN(reshapeas) {
|
||||||
|
|
||||||
auto inputShapeInfo = inputShape->at(1);
|
return SHAPELIST(ShapeBuilders::copyShapeInfo(INPUT_VARIABLE(1)->getShapeInfo(), false, block.workspace()));
|
||||||
int shapeInfoLength = inputShapeInfo[0]*2 + 4;
|
}
|
||||||
|
|
||||||
Nd4jLong* outputShapeInfo(nullptr);
|
|
||||||
COPY_SHAPE(inputShapeInfo, outputShapeInfo);
|
|
||||||
|
|
||||||
return SHAPELIST(CONSTANT(outputShapeInfo));
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_TYPES(reshapeas) {
|
DECLARE_TYPES(reshapeas) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
|
|
|
@ -15,7 +15,8 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver119 on 29/10/17.
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -27,111 +28,50 @@
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
|
CUSTOM_OP_IMPL(transpose, 1, 1, false, 0, 0) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto x = INPUT_VARIABLE(0);
|
||||||
if (block.width() == 1) {
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
if (block.isInplace()) {
|
|
||||||
x->transposei();
|
|
||||||
STORE_RESULT(*x);
|
|
||||||
} else {
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
auto t = x->transpose();
|
|
||||||
output->assign(t);
|
|
||||||
STORE_RESULT(*output);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// this is tf-mode transpose, that's nd4j permute
|
|
||||||
bool replace = false;
|
|
||||||
std::vector<int> arguments(*block.getIArguments());
|
|
||||||
|
|
||||||
auto w = block.width();
|
//Special case: empty input -> output must also be empty, nothing to transpose
|
||||||
auto a = arguments.size();
|
if (x->isEmpty()) {
|
||||||
|
REQUIRE_TRUE(z->isEmpty(), 0, "TRANSPOSE OP: when input is empty, output must also be empty");
|
||||||
if (w == 2 && a == 0) {
|
return Status::OK(); //No op
|
||||||
auto axis = INPUT_VARIABLE(1);
|
|
||||||
for (int e = 0; e < axis->lengthOf(); e++) {
|
|
||||||
auto ax = axis->e<int>(e);
|
|
||||||
if (ax < 0)
|
|
||||||
ax += x->rankOf();
|
|
||||||
|
|
||||||
arguments.emplace_back(ax);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
replace = true;
|
if (block.width() == 1 && block.getIArguments()->size() == 0) {
|
||||||
} else if (a == 0) {
|
z->assign(x->transpose());
|
||||||
for (int e = x->rankOf() - 1; e >= 0; e--)
|
|
||||||
arguments.emplace_back(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 0D edge case
|
|
||||||
if (x->rankOf() == 0) {
|
|
||||||
REQUIRE_TRUE(arguments.size() == 1, 0, "Permute: only one axis is allowed for scalar");
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
if (!block.isInplace())
|
|
||||||
output->assign(x);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if(block.isInplace()) { // in-place
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
x->permutei(arguments);
|
|
||||||
STORE_RESULT(x);
|
z->assign(x->permute(permutationVector));
|
||||||
} else {
|
|
||||||
auto input = x->permute(arguments);
|
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
output->assign(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(transpose) {
|
DECLARE_TYPES(transpose) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(nd4j::DataType::ANY)
|
||||||
->setSameMode(true);
|
->setSameMode(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(transpose) {
|
||||||
|
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (block.width() == 1 && block.getIArguments()->size() == 0)
|
||||||
|
return SHAPELIST(ShapeUtils::evalTranspShapeInfo(*x, block.workspace(), true));
|
||||||
|
|
||||||
|
std::vector<int> permutationVector = block.width() > 1 ? INPUT_VARIABLE(1)->asVectorT<int>() : *block.getIArguments();
|
||||||
|
|
||||||
|
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(permutationVector.data(), x->rankOf(), *x, block.workspace(), true);
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(transpose) {
|
|
||||||
if (block.width() == 1) {
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalTranspShapeInfo(*INPUT_VARIABLE(0), block.workspace());
|
|
||||||
return SHAPELIST(outputShapeInfo);
|
return SHAPELIST(outputShapeInfo);
|
||||||
} else {
|
}
|
||||||
// this is basically permute mode
|
|
||||||
auto shapeList = SHAPELIST();
|
|
||||||
auto arguments = block.getIArguments();
|
|
||||||
if (shape::rank(inputShape->at(0)) == 0) {
|
|
||||||
Nd4jLong *newshape;
|
|
||||||
ALLOCATE(newshape, block.getWorkspace(), shape::shapeInfoLength(inputShape->at(0)), Nd4jLong);
|
|
||||||
newshape[0] = 0;
|
|
||||||
newshape[1] = 0;
|
|
||||||
newshape[2] = 1;
|
|
||||||
newshape[3] = 99;
|
|
||||||
ArrayOptions::copyDataType(newshape, inputShape->at(0));
|
|
||||||
shapeList->push_back(newshape);
|
|
||||||
} else if (arguments->size() > 0 || inputShape->size() > 1) {
|
|
||||||
auto axis = arguments->size() > 0 ? *arguments : (INPUT_VARIABLE(1))->template asVectorT<int>();
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(axis.data(), axis.size(), *INPUT_VARIABLE(0), block.workspace());
|
|
||||||
shapeList->push_back(outputShapeInfo);
|
|
||||||
} else if (inputShape->size() == 2) {
|
|
||||||
// dead end
|
|
||||||
auto axis = INPUT_VARIABLE(1);
|
|
||||||
auto axisV = axis->template asVectorT<Nd4jLong>();
|
|
||||||
auto newshape = ShapeUtils::evalPermShapeInfo(axisV.data(), axisV.size(), *INPUT_VARIABLE(0), block.workspace());
|
|
||||||
shapeList->push_back(newshape);
|
|
||||||
} else {
|
|
||||||
int rank = shape::rank(inputShape->at(0));
|
|
||||||
for (int e = rank - 1; e >= 0; e--)
|
|
||||||
arguments->emplace_back(e);
|
|
||||||
|
|
||||||
auto outputShapeInfo = ShapeUtils::evalPermShapeInfo(arguments->data(), arguments->size(), *INPUT_VARIABLE(0), block.workspace());
|
|
||||||
shapeList->push_back(outputShapeInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
return shapeList;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1882,36 +1882,6 @@ TEST_F(DeclarableOpsTests1, TestGemv1) {
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, Reshape1) {
|
|
||||||
const std::vector<Nd4jLong> xShape = {5,4,3};
|
|
||||||
const std::vector<Nd4jLong> yShape = {3,5,4};
|
|
||||||
|
|
||||||
auto x = NDArrayFactory::create_<float>('f', xShape);
|
|
||||||
auto y = NDArrayFactory::create_<float>('f', yShape);
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, true);
|
|
||||||
block->fillInputs({-1});
|
|
||||||
std::vector<int>* arguments = block->getIArguments();
|
|
||||||
arguments->push_back(-y->ordering());
|
|
||||||
arguments->push_back(3);
|
|
||||||
arguments->push_back(5);
|
|
||||||
arguments->push_back(4);
|
|
||||||
|
|
||||||
nd4j::ops::reshape reshape;
|
|
||||||
|
|
||||||
reshape.execute(block);
|
|
||||||
|
|
||||||
ASSERT_TRUE(x->isSameShape(y));
|
|
||||||
|
|
||||||
delete y;
|
|
||||||
delete block;
|
|
||||||
delete variableSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, Reshape2) {
|
TEST_F(DeclarableOpsTests1, Reshape2) {
|
||||||
const std::vector<Nd4jLong> xShape = {5,4,3};
|
const std::vector<Nd4jLong> xShape = {5,4,3};
|
||||||
|
@ -2022,37 +1992,8 @@ TEST_F(DeclarableOpsTests1, Reshape7){
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, Transpose1) {
|
TEST_F(DeclarableOpsTests1, Transpose1) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create_<float>('c', {3,5,2});
|
auto x = NDArrayFactory::create_<float>('c', {3,5,2});
|
||||||
auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
|
auto exp = NDArrayFactory::create_<float>('c', {2,5,3});
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, true); // in-place
|
|
||||||
block->fillInputs({-1});
|
|
||||||
nd4j::ops::transpose transpose;
|
|
||||||
|
|
||||||
Nd4jStatus status = transpose.execute(block);
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
|
||||||
// ASSERT_TRUE(x.isSameShapeStrict(exp));
|
|
||||||
|
|
||||||
for (int e = 0; e < x->rankOf() * 2 + 2; e++) {
|
|
||||||
ASSERT_EQ(x->getShapeInfo()[e], exp->getShapeInfo()[e]);
|
|
||||||
}
|
|
||||||
// ASSERT_EQ(x.getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
|
|
||||||
ASSERT_EQ(x->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
|
|
||||||
|
|
||||||
delete exp;
|
|
||||||
delete block;
|
|
||||||
delete variableSpace;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, Transpose2) {
|
|
||||||
auto x = NDArrayFactory::create_<float>('c', {3,5,2});
|
|
||||||
auto exp = NDArrayFactory::create_<float>('f', {2,5,3});
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
auto variableSpace = new VariableSpace();
|
||||||
variableSpace->putVariable(-1, x);
|
variableSpace->putVariable(-1, x);
|
||||||
|
@ -2066,12 +2007,10 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
||||||
// ASSERT_TRUE(result->isSameShapeStrict(exp));
|
|
||||||
for (int e = 0; e < result->rankOf() * 2 + 2; e++) {
|
ASSERT_TRUE(exp->isSameShape(result));
|
||||||
ASSERT_EQ(result->getShapeInfo()[e], exp->getShapeInfo()[e]);
|
ASSERT_TRUE(exp->dataType() == result->dataType());
|
||||||
}
|
ASSERT_TRUE(exp->ordering() == result->ordering());
|
||||||
//ASSERT_EQ(result->getShapeInfo()[x.rankOf() * 2 + 2],-exp.getShapeInfo()[x.rankOf() * 2 + 2]);
|
|
||||||
ASSERT_EQ(result->getShapeInfo()[x->rankOf() * 2 + 3], exp->getShapeInfo()[x->rankOf() * 2 + 3]);
|
|
||||||
|
|
||||||
delete exp;
|
delete exp;
|
||||||
delete block;
|
delete block;
|
||||||
|
@ -2079,44 +2018,12 @@ TEST_F(DeclarableOpsTests1, Transpose2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
// in-place
|
|
||||||
TEST_F(DeclarableOpsTests1, Permute1) {
|
|
||||||
|
|
||||||
Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
|
|
||||||
Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
|
|
||||||
const std::vector<int> perm = {2, 0, 1};
|
|
||||||
ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
|
|
||||||
ArrayOptions::setDataType(shapeExp, nd4j::DataType::FLOAT32);
|
|
||||||
|
|
||||||
auto x = new NDArray(shapeX,true);
|
|
||||||
auto exp = new NDArray(shapeExp,true);
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, true); // in-place
|
|
||||||
block->fillInputs({-1});
|
|
||||||
std::vector<int>* arguments = block->getIArguments();
|
|
||||||
*arguments = perm; // set dimensions to be permuted
|
|
||||||
|
|
||||||
nd4j::ops::permute permute;
|
|
||||||
Nd4jStatus status = permute.execute(block);
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
|
||||||
|
|
||||||
ASSERT_TRUE(x->isSameShapeStrict(*exp));
|
|
||||||
|
|
||||||
delete exp;
|
|
||||||
delete block;
|
|
||||||
delete variableSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
// not-in-place
|
// not-in-place
|
||||||
TEST_F(DeclarableOpsTests1, Permute2) {
|
TEST_F(DeclarableOpsTests1, Permute1) {
|
||||||
|
|
||||||
Nd4jLong shapeX[] = {3, 5, 10, 15, 150, 15, 1, 0, 1, 99};
|
Nd4jLong shapeX[] = {3, 5,10,15, 150,15,1, 0,1,99};
|
||||||
Nd4jLong shapeExp[] = {3, 15, 5, 10, 1, 150, 15, 0, 0, 99};
|
Nd4jLong shapeExp[] = {3, 15,5,10, 50,10,1, 0,1,99};
|
||||||
const std::vector<int> perm = {2, 0, 1};
|
const std::vector<int> perm = {2, 0, 1};
|
||||||
|
|
||||||
ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
|
ArrayOptions::setDataType(shapeX, nd4j::DataType::FLOAT32);
|
||||||
|
|
|
@ -161,23 +161,6 @@ TEST_F(EmptyTests, Test_Reshape_1) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_2) {
|
|
||||||
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
|
||||||
auto exp = NDArrayFactory::create<float>(119.0f);
|
|
||||||
auto empty = NDArrayFactory::empty_<Nd4jLong>();
|
|
||||||
|
|
||||||
nd4j::ops::reshape op;
|
|
||||||
auto result = op.evaluate({&vector, empty}, {}, {}, {}, {}, true);
|
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
|
||||||
|
|
||||||
ASSERT_EQ(exp, *result->at(0));
|
|
||||||
ASSERT_EQ(exp, vector);
|
|
||||||
|
|
||||||
delete empty;
|
|
||||||
delete result;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_3) {
|
TEST_F(EmptyTests, Test_Reshape_3) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 0, 2});
|
||||||
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
auto y = NDArrayFactory::create<int>('c', {2}, {10, 0});
|
||||||
|
|
|
@ -65,7 +65,7 @@ TEST_F(PlaygroundTests, test_avx) {
|
||||||
nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel());
|
nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
TEST_F(PlaygroundTests, test_bert_1) {
|
TEST_F(PlaygroundTests, test_bert_1) {
|
||||||
// this test will run ONLY if this model exists
|
// this test will run ONLY if this model exists
|
||||||
if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
|
if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0)
|
||||||
|
@ -86,15 +86,15 @@ TEST_F(PlaygroundTests, test_bert_1) {
|
||||||
graph->getVariableSpace()->putVariable(86,0, u);
|
graph->getVariableSpace()->putVariable(86,0, u);
|
||||||
graph->getVariableSpace()->putVariable(87,0, v);
|
graph->getVariableSpace()->putVariable(87,0, v);
|
||||||
|
|
||||||
/*
|
|
||||||
// validating graph now
|
|
||||||
auto status = GraphExecutioner::execute(graph);
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
|
||||||
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
|
|
||||||
|
|
||||||
auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
|
// validating graph now
|
||||||
ASSERT_EQ(z, *array);
|
// auto status = GraphExecutioner::execute(graph);
|
||||||
*/
|
// ASSERT_EQ(Status::OK(), status);
|
||||||
|
// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
|
||||||
|
|
||||||
|
// auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
|
||||||
|
// ASSERT_EQ(z, *array);
|
||||||
|
|
||||||
|
|
||||||
nd4j::Environment::getInstance()->setProfiling(true);
|
nd4j::Environment::getInstance()->setProfiling(true);
|
||||||
auto profile = GraphProfilingHelper::profile(graph, 1);
|
auto profile = GraphProfilingHelper::profile(graph, 1);
|
||||||
|
@ -104,28 +104,27 @@ TEST_F(PlaygroundTests, test_bert_1) {
|
||||||
nd4j::Environment::getInstance()->setProfiling(false);
|
nd4j::Environment::getInstance()->setProfiling(false);
|
||||||
delete profile;
|
delete profile;
|
||||||
|
|
||||||
/*
|
|
||||||
std::vector<Nd4jLong> values;
|
|
||||||
|
|
||||||
for (int e = 0; e < 1; e++) {
|
// std::vector<Nd4jLong> values;
|
||||||
auto timeStart = std::chrono::system_clock::now();
|
|
||||||
|
|
||||||
GraphExecutioner::execute(graph);
|
// for (int e = 0; e < 1; e++) {
|
||||||
|
// auto timeStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
auto timeEnd = std::chrono::system_clock::now();
|
// GraphExecutioner::execute(graph);
|
||||||
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
|
||||||
values.emplace_back(outerTime);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::sort(values.begin(), values.end());
|
// auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
// auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
||||||
|
// values.emplace_back(outerTime);
|
||||||
|
// }
|
||||||
|
|
||||||
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
// std::sort(values.begin(), values.end());
|
||||||
*/
|
|
||||||
|
// nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||||
|
|
||||||
delete graph;
|
delete graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
TEST_F(PlaygroundTests, test_broadcast_1) {
|
TEST_F(PlaygroundTests, test_broadcast_1) {
|
||||||
int pool = 10;
|
int pool = 10;
|
||||||
std::vector<NDArray*> aX(pool);
|
std::vector<NDArray*> aX(pool);
|
||||||
|
|
Loading…
Reference in New Issue