Oleh tenzor mmul (#231)

* Libnd4j: TensorMMul backprop op #8174, raw implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 merge master and some corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 algorithm update, need testing, sync with  master

* Libnd4j: TensorMMul backprop op #8174 fixed incorrect B axes calculation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 optimize axes identification and fix bug of indeces overlapping, added first test. need testing with different shapes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 some fixes and improvements need more testing

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed order of matrix multiply

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed issue of incorrect axes definition, add tests based on TF, need additional testing for case dLdC not equal 1

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed scalar case add test

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed bp algorithm, axes definition, need some mode testing with different orders combination f,c; c,f f,f and add some checks for inputs

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 some checks and corrections added tests, exists the problem with different input orders support A-f B-c and A-f B-f

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 sync master

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b,permutForC)

Signed-off-by: Yurii <iuriish@yahoo.com>

* Libnd4j: TensorMMul backprop op #8174 code clean up and refactoring

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - add check for linspase ordered permutations in ShapeUtils::evalShapeForTensorDot

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide additional code in shape::reshape stuff in order to reduce amount of allocation/copy operations during reshaping procedure

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on problem of wrong shape evaluation during permute/reshape procedures

Signed-off-by: Yurii <iuriish@yahoo.com>

* - still looking for bug reason in reshape/permute stuff

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in transform cuda native ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in NDArray::assign

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove old shape::reshape stuff

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add possibility to disable copy of old buffer to new buffer during reshape operation in NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in tensorDot which had to do with wrong pointers assigments

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Oleh <oleg.semeniv@gmail.com>
master
Yurii Shyrma 2020-02-13 19:33:54 +02:00 committed by GitHub
parent 8c0e378ec3
commit fe47f52896
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 1524 additions and 901 deletions

View File

@ -999,14 +999,14 @@ namespace nd4j {
* set new order and shape in case of suitable array length (in-place operation)
* order - order to set
* shape - shape to set
*
* copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping
* if there was permute applied before or there are weird strides, then new buffer is allocated for array
*/
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const std::vector<Nd4jLong>& shape);
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
/**
* creates new array with corresponding order and shape, new array will point on _buffer of this array
@ -1015,8 +1015,8 @@ namespace nd4j {
*
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
*/
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
/**
* calculate strides and set given order

View File

@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
}
// memcpy is allowed only for same order && same ews (being equal to 1)
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
// memcpy is allowed only for same order c && same ews (being equal to 1)
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
//////////////////////////////////////////////////////////////////////////
void NDArray::printShapeInfo(const char * msg) const {
//shape::printShapeInfo(_shapeInfo);
if (msg == nullptr)
shape::printShapeInfoLinear(_shapeInfo);
else {
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
printf("%s: [", msg);
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
printf("%lld", (long long) _shapeInfo[i]);
if (i < lim - 1)
printf(", ");
}
printf("]\n");
if(msg != nullptr)
printf("shapeInfo %s: [", msg);
else
printf("shapeInfo: [");
printf("%i, ", rank);
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
if(i == rank + 1)
printf(" ");
printf("%lld,", _shapeInfo[i]);
}
printf(" %lld,", shape::type(_shapeInfo));
printf("%lld,", shape::elementWiseStride(_shapeInfo));
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
fflush(stdout);
}
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
std::vector<Nd4jLong> vShape(shape);
return reshapei(order, vShape);
return reshapei(order, vShape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
//////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.reshapei(order, shape);
newArr.reshapei(order, shape, copyToNewBuff);
return newArr;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
this->reshapei(order, shape);
this->reshapei(order, shape, copyToNewBuff);
return std::move(*this);
}
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
@ -3293,18 +3298,14 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
// we can do this only if there was no permute applied, or there are no weird strides
if (canReshape) {
if(ordering() == 'c' && order == 'f')
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
shape::setEws(shapeInfoNew, arrLength);
setShapeInfo(shapeInfoNew);
}
else {
NDArray temp(order, shape, dataType(), getContext());
if(copyToNewBuff)
this->applyTransform(transform::Assign, temp, nullptr);
*this = std::move(temp);
}
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo);
Nd4jLong offset(0), subArrLen(1);
Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride;
for (int d = rank - 1; d >= 0; --d) {
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(shapeOf[d] != 1)
stridesOf[d] *= stride;
}
}
subArrLen *= shapeOf[d];
Nd4jLong *shapeInfoNoUnities = newShapeInfo;
if(!keepUnitiesInShape) {
std::vector<int> dimsWithUnities;
for (uint d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
if(!dimsWithUnities.empty())
shapeInfoNoUnities = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
}
// check if there is possibility to set ews = 1
shape::setEws(newShapeInfo, subArrLen);
shape::checkStridesSetEwsAndOrder(shapeInfoNoUnities);
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
NDArray result(_buffer, ShapeDescriptor(shapeInfoNoUnities), getContext(), offset + getBufferOffset());
result._isView = true;
if(!keepUnitiesInShape) {
const int coeff = isStrided ? 3 : 2;
std::vector<Nd4jLong> nonUnitDims;
for (int d = 0; d < rank; ++d)
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
nonUnitDims.push_back(newShapeInfo[d+1]);
if(nonUnitDims.size() != rank)
result.reshapei(nonUnitDims);
}
RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != shapeInfoNoUnities)
RELEASE(shapeInfoNoUnities, getContext()->getWorkspace());
return result;
}

View File

@ -51,6 +51,13 @@ namespace nd4j {
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
/**
* allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);

View File

@ -68,7 +68,7 @@ namespace nd4j {
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them)
auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0)

View File

@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray aPR = a->permute(permutAt);
NDArray bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei( shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei( shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape);
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
return c;
}
//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
NDArray *cP(c), *cPR(c);
// check whether permutation is required
if(!permutForC.empty())
cP = new NDArray(c->permute(permutForC));
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt);
auto bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei(shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei(shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
cP->assign(cPR);
if(cPR != c)
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
if(cP != cPR)
delete cPR;
if(cP != c)
if(c != cP)
delete cP;
}
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
}
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
if(isAVector && bRank == 2) {
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
delete A2;
delete C2;

View File

@ -139,5 +139,15 @@ namespace nd4j {
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
}
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
Nd4jLong *outShapeInfo = nullptr;
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
return outShapeInfo;
}
}

View File

@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
permutBt = axesB;
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
uint i1, i2;
for(i1 = 0; i1 < aRank; ++i1)
if(permutAt[i1] != i1)
break;
if(i1 == aRank)
permutAt = {};
for(i2 = 0; i2 < bRank; ++i2)
if(permutBt[i2] != i2)
break;
if(i2 == bRank)
permutBt = {};
Nd4jLong n2 = 1;
for (int i = 0; i < axeAsize; i++)
n2 *= aShapeInfo[axesA[i] + 1];
shapeAt = {-1, n2};
shapeAt = {shape::length(aShapeInfo) / n2, n2};
std::vector<Nd4jLong> oldShapeA;
oldShapeA.resize(list_A.size());
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
Nd4jLong n3 = 1;
for (int i = 0; i < axeBsize; i++)
n3 *= bShapeInfo[axesB[i] + 1];
shapeBt = {n3, -1};
shapeBt = {n3, shape::length(bShapeInfo) / n3};
std::vector<Nd4jLong> oldShapeB;
oldShapeB.resize(list_B.size());
@ -306,10 +319,10 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo

View File

@ -131,7 +131,11 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder);
ND4J_EXPORT _CUDA_HD bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo);
ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo);
/**
* newShapeInfo contains rank, shape and order only, no strides/ews/type
*/
ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo);
/**
* Get the shape info buffer
@ -365,6 +369,13 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo);
/**
* shape - input inShape is shape only, not shapeInfo
* returns number of non-unity dimensions in inShape
*/
ND4J_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape);
/**
* Returns whether the
* given shape is a vector or not
@ -379,7 +390,8 @@ namespace shape {
* Returns the shape portion of an information
* buffer
*/
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *buffer);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo);
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo);
/**
* Return a copy of a buffer.
@ -994,21 +1006,16 @@ namespace shape {
// rank is equal to size of shape
ND4J_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c');
ND4J_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c');
ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c');
ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c');
// ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c');
// ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c');
ND4J_EXPORT _CUDA_HD void shapeOldScalar(nd4j::DataType dtype, Nd4jLong* const buffer, const char order);
// deduce element-wise stride
// if array is scalar or unit length vector then ews = 1
// if array is common vector then ews = stride of non-unity dimension
// if strides are normal set ews = 1, otherwise ews = 0
ND4J_EXPORT _CUDA_HD void setEws(Nd4jLong* shapeInfo, Nd4jLong len);
// deduce order and element-wise stride
// if array is scalar or unit length vector then ews = 1 and order is preserved
// if array is common vector then ews = stride of non-unity dimension and order is preserved
// if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved
ND4J_EXPORT _CUDA_HD void setOrderAndEws(Nd4jLong* shapeInfo, Nd4jLong len = -1);
ND4J_EXPORT _CUDA_HD void checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities);
ND4J_EXPORT _CUDA_HD void checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo);
/**
* processes whole set of sub-arrays
@ -1018,12 +1025,26 @@ namespace shape {
* numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs
* dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5]
* subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays
* subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays)
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
/**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
* then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order
* stridesNoUnities will point on strides in shapeNoUnities that is on {4,1}
* returns number of non-unity dimensions in inShapeInfo
* if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo
*/
ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities);
/**
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);
@ -2050,7 +2071,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank];
}
shape::setOrderAndEws(shapeInfo, len);
shape::checkStridesSetEwsAndOrder(shapeInfo);
delete[] temp;
}
@ -2227,7 +2248,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim) {
if(rank(shapeInfo) > 0 && length(shapeInfo) == 1) {
posOfNonUnityDim = 0;
posOfNonUnityDim = -1;
return true;
}
@ -2272,6 +2293,18 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
return isVector && !shapeFirstOne;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) {
int num = 0;
for(uint i = 0; i < rank; ++i)
if(inShape[i] != 1)
++num;
return num;
}
INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) {
for(int i = 0; i < rank; i++) {
if(shape[i] == shape::prodLong(shape,rank))
@ -2310,8 +2343,14 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
* Returns the shape portion of an information
* buffer
*/
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *buffer) {
return buffer + 1;
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) {
return shapeInfo + 1;
}
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) {
return shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo));
}
/**
@ -2444,7 +2483,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer);
// correct order and ews if necessary
shape::setOrderAndEws(newShapeBuffer);
shape::checkStridesSetEwsAndOrder(newShapeBuffer);
delete[] indices;
@ -3918,62 +3957,51 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) {
// return true;
// }
// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, const bool isFOrder, Nd4jLong* newShapeInfo) {
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements
// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo
// const int newOrder = isFOrder ? 102 : 99;
// const int oldOrder = oldShapeInfo[2 * oldRank + 3];
// newShapeInfo[0] = newRank;
// memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
// Nd4jLong* newStrides = shape::stride(newShapeInfo);
// const Nd4jLong* oldShape = shape::shapeOf(const_cast<Nd4jLong*>(oldShapeInfo));
// const Nd4jLong* oldStrides = shape::stride(const_cast<Nd4jLong*>(oldShapeInfo));
// int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
// while (newStart < newRank && oldStart < oldRank) {
// newDim = newShape[newStart];
// oldDim = oldShape[oldStart];
// while (newDim != oldDim)
// while (newDim != oldDim && newDim > 0 && oldDim > 0)
// if (newDim < oldDim) newDim *= newShape[newStop++];
// else oldDim *= oldShape[oldStop++];
// // ------ Check whether the original axes can be combined ------ //
// for (int i = oldStart; i < oldStop - 1; i++) {
// if(oldShape[i] == 1) { // ignore strides like {...,1,1,...}
// if(oldOrder == 102) ++oldStart;
// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) {
// if(oldShape[i] == 1) // skip unity-dimension and its stride
// continue;
// }
// if(oldOrder == 102 && oldStrides[i + 1] != oldShape[i] * oldStrides[i])
// return false; // not contiguous enough
// if(oldOrder == 99 && oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1])
// while((i + step) < oldRank && oldShape[i + step] == 1)
// ++step; // skip following unity-dimensions and its strides if such are present
// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step])
// return false; // not contiguous enough
// }
// // ------ Calculate new strides for all axes currently worked with ------ //
// if(isFOrder) {
// newStrides[newStart] = oldStrides[oldStart];
// for (int i = newStart + 1; i < newStop; ++i)
// newStrides[i] = newStrides[i - 1] * newShape[i - 1];
// }
// else {
// newStrides[newStop - 1] = oldStrides[oldStop - 1];
// for (int i = newStop - 1; i > newStart; --i)
// newStrides[i - 1] = newStrides[i] * newShape[i];
// }
// newStart = newStop++;
// oldStart = oldStop++;
// }
// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
// for (int i = newStart; i < newRank; ++i)
// newStrides[i] = 1;
// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order
// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews
// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
@ -3982,57 +4010,98 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) {
// }
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
// PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements
// also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo
INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
// copy shape from newShape into newShapeInfo
newShapeInfo[0] = newRank;
memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
Nd4jLong* newStrides = shape::stride(newShapeInfo);
const Nd4jLong* oldShape = shape::shapeOf(const_cast<Nd4jLong*>(oldShapeInfo));
const Nd4jLong* oldStrides = shape::stride(const_cast<Nd4jLong*>(oldShapeInfo));
Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
// copy order
newShapeInfo[2 * newRank + 3] = newOrder;
while (newStart < newRank && oldStart < oldRank) {
return shape::reshapeC(oldShapeInfo, newShapeInfo);
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo) {
// newShapeInfo contains rank, shape and order; but no strides, type and ews
const int newRank = shape::rank(newShapeInfo);
// if oldShapeInfo is scalar or vector with length=1
if(shape::length(oldShapeInfo) == 1) {
for (uint i = 0; i < newRank; ++i)
shape::stride(newShapeInfo)[i] = 1;
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);
*shape::ews(newShapeInfo) = 1;
return true;
}
const auto oldOrder = shape::order(oldShapeInfo);
const auto newOrder = shape::order(newShapeInfo);
const auto oldEws = shape::elementWiseStride(const_cast<Nd4jLong*>(oldShapeInfo));
if(oldEws > 0 && oldOrder != newOrder)
return false;
// *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code
// FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
Nd4jLong tempBuffer[4*MAX_RANK];
Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides;
// exclude unities from oldShapeInfo
const int oldNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides);
const int newNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides);
// *** SECOND STAGE - strides evaluation
int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) {
newDim = newShape[newStart];
oldDim = oldShape[oldStart];
while (newDim != oldDim && newDim > 0 && oldDim > 0)
if (newDim < oldDim) newDim *= newShape[newStop++];
else oldDim *= oldShape[oldStop++];
while (newDim != oldDim && newDim > 0 && oldDim > 0) {
// ------ Check whether the original axes can be combined ------ //
for (int step = 1, i = oldStart; i < oldStop - 1; ++i) {
if(oldShape[i] == 1) // skip unity-dimension and its stride
continue;
while((i + step) < oldRank && oldShape[i + step] == 1)
++step; // skip following unity-dimensions and its strides if such are present
if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step])
return false; // not contiguous enough
if (newDim < oldDim)
newDim *= newShape[newStop++];
else
oldDim *= oldShape[oldStop++];
}
newStrides[newStop - 1] = oldStrides[oldStop - 1];
for (int i = newStop - 1; i > newStart; --i)
newStrides[i - 1] = newStrides[i] * newShape[i];
// check c-contiguous of old axes range
for(uint i = oldStart; i < oldStop - 1; ++i) // do not check value of last stride, it doesn't matter
if(oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1])
return false; // not contiguous
// fill newStrides in c manner
newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride
for (int i = newStop - 2; i >= newStart; --i)
newStrides[i] = newStrides[i + 1] * newShape[i + 1];
newStart = newStop++;
oldStart = oldStop++;
}
// rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
for (int i = newStart; i < newRank; ++i)
newStrides[i] = 1;
// fill new calculated strides into newShapeInfo, take into account possible unities in shape
for (int j = 0, i = 0; i < newRank; ++i)
shape::stride(newShapeInfo)[i] = (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++];
// set ews
if(oldEws == 0)
shape::checkStridesSetEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, newShape, newStrides); // set ews and order
else {
newShapeInfo[2 * newRank + 3] = oldOrder; // order
*shape::ews(newShapeInfo) = oldEws; // ews
}
newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order
newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
return true;
}
}
INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder) {
@ -4573,129 +4642,75 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
}
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD setEws(Nd4jLong* shapeInfo, Nd4jLong len) {
INLINEDEF void _CUDA_HD checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo) {
// FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
Nd4jLong tempBuffer[2*MAX_RANK];
Nd4jLong *shape = tempBuffer, *strides;
const int rank = shape::rank(shapeInfo);
const Nd4jLong* shape = shape::shapeOf(shapeInfo);
const Nd4jLong* strides = shape::stride(shapeInfo);
const char order = shape::order(shapeInfo);
Nd4jLong* ews = shape::ews(shapeInfo);
// exclude unities from shapeInfo
const int numOfNonUnities = shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides);
if(len == -1) // calculate array length if it is not given
len = shape::length(shapeInfo);
if(len <= 1) { // empty, scalar or unity-vector case
*ews = 1;
return;
}
int nonUnityDim(0);
if(shape::isCommonVector(shapeInfo, nonUnityDim)) {
*ews = strides[nonUnityDim];
return;
}
// check last(c)/first(f) dimension, it should be equal to 1
if((order == 'c' && shape[rank - 1] != 1 && strides[rank - 1] != 1) || (order == 'f' && shape[0] != 1 && strides[0] != 1)) {
*ews = 0;
return;
}
Nd4jLong correctStride = 1;
if(order == 'c') {
for (int i = rank - 2; i >= 0 ; i--) {
correctStride *= shape[i + 1];
if(shape[i] == 1)
continue;
if(correctStride != strides[i]) {
*ews = 0;
return;
}
}
}
else {
for (int i = 1; i < rank; ++i) {
correctStride *= shape[i - 1];
if(shape[i] == 1)
continue;
if(correctStride != strides[i]) {
*ews = 0;
return;
}
}
}
*ews = 1;
shape::checkStridesSetEwsAndOrder(shapeInfo, shape::order(shapeInfo), numOfNonUnities, shape, strides);
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void setOrderAndEws(Nd4jLong* shapeInfo, Nd4jLong len) {
INLINEDEF void _CUDA_HD checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnities, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities) {
const int rank = shape::rank(shapeInfo);
const Nd4jLong* shape = shape::shapeOf(shapeInfo);
const Nd4jLong* strides = shape::stride(shapeInfo);
const char order = shape::order(shapeInfo);
Nd4jLong* ews = shape::ews(shapeInfo);
if(len == -1) // calculate array length if it is not given
len = shape::length(shapeInfo);
if(len <= 1) { // empty, scalar or unity-vector case
*ews = 1;
if(shape::length(shapeInfo) == 1) {
*shape::ews(shapeInfo) = 1;
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
return;
}
int nonUnityDim(0);
if(shape::isCommonVector(shapeInfo, nonUnityDim)) { // in this case we don't change order
*ews = strides[nonUnityDim];
if(numOfNonUnities == 1) { // case of common vector
*shape::ews(shapeInfo) = *stridesNoUnities;
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
return;
}
// check if strides are contiguous in respect to c-order
// firstly check last stride, it should be equal to 1
if (strides[rank - 1] == 1 || shape[rank - 1] == 1) { // last dimension is ok, go on through the rest dimensions in reverse order
Nd4jLong correctStride = 1;
bool cContiguous = true;
for (int i = rank - 2; i >= 0 ; i--) {
correctStride *= shape[i + 1];
if(shape[i] == 1)
continue;
if(correctStride != strides[i]) {
cContiguous = false;
bool contiguous = true;
// *** check whether strides are in c contiguous order ***//
if(stridesNoUnities[numOfNonUnities - 1] != 1) // last stride should be always unity for c order
contiguous = false;
else {
for (uint i = 0; i < numOfNonUnities - 1; ++i) {
if(stridesNoUnities[i] != stridesNoUnities[i + 1] * shapeNoUnities[i + 1]) {
contiguous = false;
break;
}
}
if(cContiguous) {
*ews = 1;
shapeInfo[shape::shapeInfoLength(rank) - 1] = 99;
}
if(contiguous) {
*shape::ews(shapeInfo) = 1;
shapeInfo[rank * 2 + 3] = 99;
return;
}
}
// now check if strides are contiguous in respect to f-order
// firstly check first stride, it should be equal to 1
if(strides[0] == 1 || shape[0] == 1) { // first dimension is ok, go on through the rest dimensions
Nd4jLong correctStride = 1;
bool fContiguous = true;
for (int i = 1; i < rank; ++i) {
correctStride *= shape[i - 1];
if(shape[i] == 1)
continue;
if(correctStride != strides[i]) {
fContiguous = false;
contiguous = true;
//*** check whether strides are in f contiguous order ***//
if(stridesNoUnities[0] != 1) // first stride should be always unity for f order
contiguous = false;
else {
for (uint i = 1; i < numOfNonUnities; ++i) {
if(stridesNoUnities[i] != stridesNoUnities[i - 1] * shapeNoUnities[i - 1]) {
contiguous = false;
break;
}
}
if(fContiguous) {
*ews = 1;
shapeInfo[shape::shapeInfoLength(rank) - 1] = 102;
}
if(contiguous) {
*shape::ews(shapeInfo) = 1;
shapeInfo[rank * 2 + 3] = 102;
return;
}
}
*ews = 0;
// if both cContiguous and fContiguous are false then order is preserved
*shape::ews(shapeInfo) = 0;
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
}
//////////////////////////////////////////////////////////////////////
@ -4709,49 +4724,42 @@ INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo
return;
}
Nd4jLong *outShapeInfo = new Nd4jLong[shape::shapeInfoLength(wholeShapeInfo)];
memcpy(outShapeInfo, wholeShapeInfo, shape::shapeInfoByteLength(wholeShapeInfo));
const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize;
subArrShapeInfo[0] = subArrRank; // rank
subArrShapeInfo[2 * subArrRank + 1] = shape::type(wholeShapeInfo); // type
subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order
Nd4jLong* shape = new Nd4jLong[dimsSize];
Nd4jLong* strides = new Nd4jLong[dimsSize];
const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize;
Nd4jLong* shapeNoUnities = nullptr;
if(!keepUnitiesInShape)
shapeNoUnities = new Nd4jLong[subArrRank];
Nd4jLong subArrLen = 1;
for(int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) {
if(j >= 0 && i == dimsToExclude[j]) {
strides[j] = shape::stride(outShapeInfo)[i];
shape[j--] = shape::shapeOf(outShapeInfo)[i];
shape::shapeOf(outShapeInfo)[i] = 1;
strides[j] = shape::stride(wholeShapeInfo)[i];
shape[j--] = shape::shapeOf(wholeShapeInfo)[i];
if(keepUnitiesInShape) {
shape::shapeOf(subArrShapeInfo)[k] = 1;
shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i];
}
}
else {
subArrLen *= shape::shapeOf(outShapeInfo)[i];
if(!keepUnitiesInShape)
shapeNoUnities[k--] = shape::shapeOf(outShapeInfo)[i];
}
shape::shapeOf(subArrShapeInfo)[k] = shape::shapeOf(wholeShapeInfo)[i];
shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i];
}
// evaluate ews
shape::setEws(outShapeInfo, subArrLen);
}
// calculation of sub-array offsets (subArrOffsets)
shape::calcOffsets(dimsSize, shape, strides, subArrOffsets);
// remove unities from outShapeInfo if required
if(!keepUnitiesInShape) {
shape::reshapeC(rank, outShapeInfo, subArrRank, shapeNoUnities, subArrShapeInfo);
delete []shapeNoUnities;
}
else
memcpy(subArrShapeInfo, outShapeInfo, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong));
// evaluate ews
shape::checkStridesSetEwsAndOrder(subArrShapeInfo);
delete []strides;
delete []shape;
delete []outShapeInfo;
}
//////////////////////////////////////////////////////////////////////
@ -4815,197 +4823,240 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
// we assume all array have same length
const Nd4jLong len = shape::length(xShapeInfo);
// // we assume all array have same length
// const Nd4jLong len = shape::length(xShapeInfo);
const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
// const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
// const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
// const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
const char xOrder = shape::order(xShapeInfo);
const char yOrder = shape::order(yShapeInfo);
const char zOrder = shape::order(zShapeInfo);
// const char xOrder = shape::order(xShapeInfo);
// const char yOrder = shape::order(yShapeInfo);
// const char zOrder = shape::order(zShapeInfo);
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) {
xOffsets = yOffsets = zOffsets = nullptr;
}
else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) {
xOffsets = yOffsets = nullptr;
zOffsets = new Nd4jLong[len];
shape::calcOffsets(zShapeInfo, zOffsets, xOrder);
}
else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) {
xOffsets = zOffsets = nullptr;
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
}
else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) {
yOffsets = zOffsets = nullptr;
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
}
else if(xEws == 1) {
xOffsets = nullptr;
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
}
PRAGMA_OMP_SECTION
{
zOffsets = new Nd4jLong[len];
shape::calcOffsets(zShapeInfo, zOffsets, xOrder);
}
}
}
else if(yEws == 1) {
yOffsets = nullptr;
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
}
PRAGMA_OMP_SECTION
{
zOffsets = new Nd4jLong[len];
shape::calcOffsets(zShapeInfo, zOffsets, yOrder);
}
}
}
else if(zEws == 1) {
zOffsets = nullptr;
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets, zOrder);
}
PRAGMA_OMP_SECTION
{
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets, zOrder);
}
}
}
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) {
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
yOffsets = zOffsets = xOffsets;
}
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
}
PRAGMA_OMP_SECTION
{
zOffsets = new Nd4jLong[len];
shape::calcOffsets(zShapeInfo, zOffsets);
}
}
yOffsets = xOffsets;
}
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
}
PRAGMA_OMP_SECTION
{
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets);
}
}
zOffsets = xOffsets;
}
else {
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
}
PRAGMA_OMP_SECTION
{
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets);
}
PRAGMA_OMP_SECTION
{
zOffsets = new Nd4jLong[len];
shape::calcOffsets(zShapeInfo, zOffsets);
// if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) {
// xOffsets = yOffsets = zOffsets = nullptr;
// }
// else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) {
// xOffsets = yOffsets = nullptr;
// zOffsets = new Nd4jLong[len];
// shape::calcOffsets(zShapeInfo, zOffsets, xOrder);
// }
// else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) {
// xOffsets = zOffsets = nullptr;
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
// }
// else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) {
// yOffsets = zOffsets = nullptr;
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
// }
// else if(xEws == 1) {
// xOffsets = nullptr;
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
// }
// PRAGMA_OMP_SECTION
// {
// zOffsets = new Nd4jLong[len];
// shape::calcOffsets(zShapeInfo, zOffsets, xOrder);
// }
// }
// }
// else if(yEws == 1) {
// yOffsets = nullptr;
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
// }
// PRAGMA_OMP_SECTION
// {
// zOffsets = new Nd4jLong[len];
// shape::calcOffsets(zShapeInfo, zOffsets, yOrder);
// }
// }
// }
// else if(zEws == 1) {
// zOffsets = nullptr;
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets, zOrder);
// }
// PRAGMA_OMP_SECTION
// {
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets, zOrder);
// }
// }
// }
// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// yOffsets = zOffsets = xOffsets;
// }
// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// }
// PRAGMA_OMP_SECTION
// {
// zOffsets = new Nd4jLong[len];
// shape::calcOffsets(zShapeInfo, zOffsets);
// }
// }
// yOffsets = xOffsets;
// }
// else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// }
// PRAGMA_OMP_SECTION
// {
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets);
// }
// }
// zOffsets = xOffsets;
// }
// else {
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// }
// PRAGMA_OMP_SECTION
// {
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets);
// }
// PRAGMA_OMP_SECTION
// {
// zOffsets = new Nd4jLong[len];
// shape::calcOffsets(zShapeInfo, zOffsets);
// }
// }
// }
// }
//////////////////////////////////////////////////////////////////////
// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) {
// // we assume all array have same length
// const Nd4jLong len = shape::length(xShapeInfo);
// const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
// const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
// const char xOrder = shape::order(xShapeInfo);
// const char yOrder = shape::order(yShapeInfo);
// const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo);
// if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) {
// xOffsets = yOffsets = nullptr;
// }
// else if(xEws == 1) {
// xOffsets = nullptr;
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
// }
// else if(yEws == 1) {
// yOffsets = nullptr;
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
// }
// else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// yOffsets = xOffsets;
// }
// else {
// PRAGMA_OMP_PARALLEL_SECTIONS
// {
// PRAGMA_OMP_SECTION
// {
// xOffsets = new Nd4jLong[len];
// shape::calcOffsets(xShapeInfo, xOffsets);
// }
// PRAGMA_OMP_SECTION
// {
// yOffsets = new Nd4jLong[len];
// shape::calcOffsets(yShapeInfo, yOffsets);
// }
// }
// }
// }
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
const int rank = shape::rank(inShapeInfo);
const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
if(numOfNonUnities == rank) { // no unities in shape, no copy procedure
shapeNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1;
stridesNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1 + rank;
return numOfNonUnities;
}
for(uint j = 0, i = 0; i < rank; ++i) {
if(shape::shapeOf(inShapeInfo)[i] != 1) {
shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
}
}
stridesNoUnities = shapeNoUnities + numOfNonUnities;
return numOfNonUnities;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) {
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) {
// we assume all array have same length
const Nd4jLong len = shape::length(xShapeInfo);
outShapeInfo[0] = inShapeInfo[0] - dimsSize;
const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
if(j < dimsSize && i == dimsToExclude[j]) {
++j;
continue;
}
const char xOrder = shape::order(xShapeInfo);
const char yOrder = shape::order(yShapeInfo);
shape::shapeOf(outShapeInfo)[k] = shape::shapeOf(inShapeInfo)[i];
shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
}
const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo);
if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) {
xOffsets = yOffsets = nullptr;
}
else if(xEws == 1) {
xOffsets = nullptr;
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
}
else if(yEws == 1) {
yOffsets = nullptr;
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
}
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
yOffsets = xOffsets;
}
else {
PRAGMA_OMP_PARALLEL_SECTIONS
{
PRAGMA_OMP_SECTION
{
xOffsets = new Nd4jLong[len];
shape::calcOffsets(xShapeInfo, xOffsets);
}
PRAGMA_OMP_SECTION
{
yOffsets = new Nd4jLong[len];
shape::calcOffsets(yShapeInfo, yOffsets);
}
}
}
outShapeInfo[2 * outShapeInfo[0] + 1] = shape::type(inShapeInfo); // type
*shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
}
}
#endif /* SHAPE_H_ */

View File

@ -84,7 +84,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -97,7 +97,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -87,7 +87,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -21,17 +21,22 @@
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_tensormmul)
#include <numeric>
#include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h>
#include <MmulHelper.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
namespace ops {
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
auto c = OUTPUT_VARIABLE(0); //
auto c = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
@ -40,20 +45,20 @@ namespace nd4j {
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
axes_0[e] = (int)INT_ARG(e + 1);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
}
DECLARE_SYN(tensordot, tensormmul);
DECLARE_SHAPE_FN(tensormmul) {
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
@ -76,15 +81,114 @@ namespace nd4j {
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
}
DECLARE_TYPES(tensormmul) {
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
auto A = INPUT_VARIABLE(0);
auto B = INPUT_VARIABLE(1);
auto dLdC = INPUT_VARIABLE(2);
auto dLdA = OUTPUT_VARIABLE(0);
auto dLdB = OUTPUT_VARIABLE(1);
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
int axe0Size = INT_ARG(0);
int axe1Size = INT_ARG(axe0Size + 1);
auto Arank = A->rankOf();
auto Brank = B->rankOf();
auto dLdCrank = dLdC->rankOf();
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
// building axes
std::vector<int> axes0(axe0Size), axes1(axe1Size);
for (uint e = 0; e < axe0Size; e++)
axes0[e] = (int)INT_ARG(e + 1);
for (uint e = 0; e < axe1Size; e++)
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
// special case for scalar value
if (dLdC->isScalar()) {
dLdA->assign((*dLdC) * *B);
dLdB->assign((*dLdC) * *A);
return Status::OK();
}
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// rank always have to be divided by 2
std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2);
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
}
else {
axesAdLdC.push_back(0);
axesBdLdC.push_back(0);
}
// calculate dLdA
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
// calculate dLdB
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
return Status::OK();
}
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul_bp) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
auto dLShapeInfo = inputShape->at(2);
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
Nd4jLong* dLdAShapeInfo = nullptr;
Nd4jLong* dLdBShapeInfo = nullptr;
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
}
}
}
#endif

View File

@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d conv2d;
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});

View File

@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -61,7 +61,7 @@ namespace nd4j {
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}

View File

@ -62,7 +62,7 @@ namespace nd4j {
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
}

View File

@ -43,7 +43,7 @@ namespace nd4j {
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);

View File

@ -63,7 +63,7 @@ namespace nd4j {
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}

View File

@ -71,7 +71,7 @@ namespace nd4j {
}
if (block.isInplace()) {
output->reshapei(input->ordering(), shape);
output->reshapei(input->ordering(), shape, false);
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);

View File

@ -58,6 +58,7 @@ namespace nd4j {
*/
#if NOT_EXCLUDED(OP_tensormmul)
DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1);
#endif
/**

View File

@ -432,7 +432,7 @@ namespace nd4j {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
@ -505,7 +505,7 @@ namespace nd4j {
if(gradB) {
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
if(gradBR != gradB)

View File

@ -30,7 +30,7 @@ namespace helpers {
void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
auto _a = a->reshape(a->ordering(), {-1, 3});
auto _b = b->reshape(b->ordering(), {-1, 3});
auto _o = o->reshape(o->ordering(), {-1, 3});
auto _o = o->reshape(o->ordering(), {-1, 3}, false);
auto tadsA = _a.allTensorsAlongDimension({1});
auto tadsB = _b.allTensorsAlongDimension({1});

View File

@ -244,14 +244,14 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)});
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false);
outputRearranged0.permutei({2, 3,0, 4,1, 5});
if(input.lengthOf() == output.lengthOf()) {
outputRearranged0.assign(input);
}
else {
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)});
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false);
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);
if(output.getBuffer() != outputRearranged1.getBuffer())
@ -352,7 +352,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(int j = 1; j < rank; ++i, ++j)
temp[i] = output.sizeAt(j);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
//*** construct permuting std::vector for permutation of output array ***//
@ -382,7 +382,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(i = 1; i < rank; ++i)
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES);

View File

@ -59,7 +59,7 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND
void FORCEINLINE _crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
auto a_ = a->reshape(a->ordering(), {-1, 3});
auto b_ = b->reshape(b->ordering(), {-1, 3});
auto o_ = o->reshape(o->ordering(), {-1, 3});
auto o_ = o->reshape(o->ordering(), {-1, 3}, false);
auto tadsA = a_.allTensorsAlongDimension({1});
auto tadsB = b_.allTensorsAlongDimension({1});

View File

@ -322,7 +322,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW
if(gradBR != gradB)
delete gradBR;
}
@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW
if(gradBR != gradB)
delete gradBR;
}

View File

@ -313,7 +313,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false);
outputRearranged0.permutei({2, 3,0, 4,1, 5});
if(input.lengthOf() == output.lengthOf()) {
@ -322,7 +322,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
}
else {
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false);
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@ -439,7 +439,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(int j = 1; j < rank; ++i, ++j)
temp[i] = output.sizeAt(j);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
//*** construct permuting std::vector for permutation of output array ***//
@ -469,7 +469,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(i = 1; i < rank; ++i)
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

View File

@ -471,9 +471,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
if(cI)
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
if(hL)
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
if(cL)
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));
lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);

View File

@ -321,6 +321,280 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot5) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot6) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot7) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot8) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot9) {
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
// z.linspace(1);
// z.printShapeInfo();
// z.printIndexedBuffer();
// z.reshapei('c', {4,3});
// z.printShapeInfo();
// z.printIndexedBuffer();
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot10) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot11) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot12) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot13) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot14) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot15) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot16) {
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot17) {
NDArray x('f', {16,16}, nd4j::DataType::FLOAT32);
NDArray y('f', {1000,16}, nd4j::DataType::FLOAT32);
NDArray z('c', {16,1000}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto status = op.execute({&x, &y}, {&z}, {}, {1,1, 1,1}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, DivergentCheck1) {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation("switch");

View File

@ -708,30 +708,6 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, tensormmul_6) {
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
// exp.printShapeInfo();
// result->printShapeInfo();
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {

View File

@ -1560,3 +1560,447 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
delete resultsB;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 1 }, { 1 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(B.isSameShape(*dLdAbp));
ASSERT_TRUE(B.equalsTo(*dLdAbp));
ASSERT_TRUE(A.isSameShape(*dLdBbp));
ASSERT_TRUE(A.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) {
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, nd4j::DataType::FLOAT32);
auto dLdC = NDArrayFactory::create<float>(1.f);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(B.isSameShape(*dLdAbp));
ASSERT_TRUE(B.equalsTo(*dLdAbp));
ASSERT_TRUE(A.isSameShape(*dLdBbp));
ASSERT_TRUE(A.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) {
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) {
NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) {
NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) {
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) {
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) {
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::DOUBLE);
NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, nd4j::DataType::DOUBLE);
NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, nd4j::DataType::DOUBLE);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) {
NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::DOUBLE);
NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, nd4j::DataType::DOUBLE);
NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, nd4j::DataType::DOUBLE);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) {
NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, nd4j::DataType::FLOAT32);
NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 });
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto* dLdA = results->at(0);
auto* dLdB = results->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
nd4j::ops::tensormmul op;
nd4j::ops::tensormmul_bp op_bp;
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
ASSERT_TRUE(isGradCorrect);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
nd4j::ops::tensormmul op;
nd4j::ops::tensormmul_bp op_bp;
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
ASSERT_TRUE(isGradCorrect);
}

View File

@ -578,246 +578,6 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot5) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot6) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot7) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot8) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot9) {
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
// z.linspace(1);
// z.printShapeInfo();
// z.printIndexedBuffer();
// z.reshapei('c', {4,3});
// z.printShapeInfo();
// z.printIndexedBuffer();
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot10) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot11) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot12) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot13) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot14) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot15) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {

View File

@ -2043,34 +2043,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN);
ASSERT_TRUE(isGradCorrect);
//************************************//
/* exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z));
delete result;
*/
//************************************//
/* exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z));
delete result;
*/
//************************************//
/* exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z));
delete result;
*/
}
////////////////////////////////////////////////////////////////////////////////
@ -2079,11 +2051,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
auto inputC = NDArrayFactory::create<double>('c', {2, 2});
auto axis = NDArrayFactory::create<double>(1.);
// auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.});
// auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024});
// auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++
// auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
auto gradO = NDArrayFactory::create<double>('c', {2, 2});
int exclusive, reverse;