Oleh tensor mmul (#231)
* Libnd4j: TensorMMul backprop op #8174, raw implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 merge master and some corrections Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 algorithm update, needs testing, sync with master
* Libnd4j: TensorMMul backprop op #8174 fixed incorrect B axes calculation Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 optimize axes identification and fix bug of overlapping indices, added first test; needs testing with different shapes Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 some fixes and improvements, needs more testing Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed order of matrix multiply Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed issue of incorrect axes definition, added tests based on TF; needs additional testing for the case where dLdC is not equal to 1 Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed scalar case, added test Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 fixed bp algorithm and axes definition; needs some more testing with different order combinations (f,c; c,f; f,f) and some input checks Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 added some checks and corrections, added tests; a problem remains with support for different input orders (A-f B-c and A-f B-f) Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* Libnd4j: TensorMMul backprop op #8174 sync master Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* - correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b, permutForC) Signed-off-by: Yurii <iuriish@yahoo.com>
* Libnd4j: TensorMMul backprop op #8174 code clean up and refactoring Signed-off-by: Oleg <oleg.semeniv@gmail.com>
* - add check for linspace-ordered (identity) permutations in ShapeUtils::evalShapeForTensorDot Signed-off-by: Yurii <iuriish@yahoo.com>
* - provide additional code in the shape::reshape stuff in order to reduce the amount of allocation/copy operations during the reshaping procedure Signed-off-by: Yurii <iuriish@yahoo.com>
* - further work on the problem of wrong shape evaluation during permute/reshape procedures Signed-off-by: Yurii <iuriish@yahoo.com>
* - still looking for the cause of the bug in the reshape/permute stuff Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in transform cuda native ops Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in NDArray::assign Signed-off-by: Yurii <iuriish@yahoo.com>
* - remove old shape::reshape stuff Signed-off-by: Yurii <iuriish@yahoo.com>
* - add possibility to disable copying of the old buffer to the new buffer during reshape operation in NDArray class Signed-off-by: Yurii <iuriish@yahoo.com>
* - correct bug in tensorDot which had to do with wrong pointer assignments Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Oleh <oleg.semeniv@gmail.com>
parent 8c0e378ec3
commit fe47f52896
@@ -999,14 +999,14 @@ namespace nd4j {
|
||||||
* set new order and shape in case of suitable array length (in-place operation)
|
* set new order and shape in case of suitable array length (in-place operation)
|
||||||
* order - order to set
|
* order - order to set
|
||||||
* shape - shape to set
|
* shape - shape to set
|
||||||
*
|
* copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping
|
||||||
* if there was permute applied before or there are weird strides, then new buffer is allocated for array
|
* if there was permute applied before or there are weird strides, then new buffer is allocated for array
|
||||||
*/
|
*/
|
||||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
|
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
|
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
|
|
||||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
|
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
bool reshapei(const std::vector<Nd4jLong>& shape);
|
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* creates new array with corresponding order and shape, new array will point on _buffer of this array
|
* creates new array with corresponding order and shape, new array will point on _buffer of this array
|
||||||
|
@@ -1015,8 +1015,8 @@ namespace nd4j {
|
||||||
*
|
*
|
||||||
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
|
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
|
||||||
*/
|
*/
|
||||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
|
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
|
||||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
|
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
|
||||||
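The new copyToNewBuff flag only matters when the reshape cannot be done in place (a previously permuted view or otherwise non-contiguous strides force allocation of a fresh buffer). A hedged usage sketch, assuming the NDArrayFactory::create factory used elsewhere in libnd4j; header names and calls here are illustrative, not part of this change:

#include <NDArray.h>
#include <NDArrayFactory.h>       // header names assumed; adjust to the local include layout

void reshapeCopyFlagSketch() {
    auto x = nd4j::NDArrayFactory::create<float>('c', {2, 3, 4});
    x.linspace(1);                                     // fill with 1..24

    // contiguous case: reshape is a metadata change, copyToNewBuff is irrelevant
    auto y = x.reshape('c', {6, 4});

    // permuted case: strides are no longer contiguous, so reshape must allocate;
    // pass copyToNewBuff = false when the result will be overwritten anyway
    auto xp = x.permute({2, 0, 1});
    auto z  = xp.reshape('c', {4, 6}, /*copyToNewBuff=*/false);
}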
|
|
||||||
/**
|
/**
|
||||||
* calculate strides and set given order
|
* calculate strides and set given order
|
||||||
|
|
|
@@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
|
||||||
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
||||||
}
|
}
|
||||||
|
|
||||||
// memcpy is allowed only for same order && same ews (being equal to 1)
|
// memcpy is allowed only when both arrays have c ordering && the same ews (being equal to 1)
|
||||||
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||||
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
||||||
else {
|
else {
|
||||||
NDArray::prepareSpecialUse({this}, {&other});
|
NDArray::prepareSpecialUse({this}, {&other});
|
||||||
|
@@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::printShapeInfo(const char * msg) const {
|
void NDArray::printShapeInfo(const char * msg) const {
|
||||||
//shape::printShapeInfo(_shapeInfo);
|
|
||||||
if (msg == nullptr)
|
|
||||||
shape::printShapeInfoLinear(_shapeInfo);
|
|
||||||
else {
|
|
||||||
int rank = shape::rank(_shapeInfo);
|
int rank = shape::rank(_shapeInfo);
|
||||||
int lim = shape::shapeInfoLength(rank);
|
int lim = shape::shapeInfoLength(rank);
|
||||||
printf("%s: [", msg);
|
|
||||||
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
|
if(msg != nullptr)
|
||||||
printf("%lld", (long long) _shapeInfo[i]);
|
printf("shapeInfo %s: [", msg);
|
||||||
if (i < lim - 1)
|
else
|
||||||
printf(", ");
|
printf("shapeInfo: [");
|
||||||
}
|
|
||||||
printf("]\n");
|
printf("%i, ", rank);
|
||||||
|
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||||
|
if(i == rank + 1)
|
||||||
|
printf(" ");
|
||||||
|
printf("%lld,", _shapeInfo[i]);
|
||||||
}
|
}
|
||||||
|
printf(" %lld,", shape::type(_shapeInfo));
|
||||||
|
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||||
|
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||||
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// set new order and shape in case of suitable array length
|
// set new order and shape in case of suitable array length
|
||||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
std::vector<Nd4jLong> vShape(shape);
|
std::vector<Nd4jLong> vShape(shape);
|
||||||
return reshapei(order, vShape);
|
return reshapei(order, vShape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
return reshapei('c', shape);
|
return reshapei(ordering(), shape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
|
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||||
return reshapei('c', shape);
|
return reshapei(ordering(), shape, copyToNewBuff);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
||||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
|
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
|
||||||
|
|
||||||
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||||
newArr.reshapei(order, shape);
|
newArr.reshapei(order, shape, copyToNewBuff);
|
||||||
|
|
||||||
return newArr;
|
return newArr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
|
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
|
||||||
|
|
||||||
this->reshapei(order, shape);
|
this->reshapei(order, shape, copyToNewBuff);
|
||||||
return std::move(*this);
|
return std::move(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// set new order and shape in case of suitable array length
|
// set new order and shape in case of suitable array length
|
||||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
|
||||||
|
|
||||||
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
||||||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||||
|
@@ -3293,18 +3298,14 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||||
Nd4jLong *shapeInfoNew;
|
Nd4jLong *shapeInfoNew;
|
||||||
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||||
|
|
||||||
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
|
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
|
||||||
|
|
||||||
// we can do this only if there was no permute applied, or there are no weird strides
|
|
||||||
if (canReshape) {
|
if (canReshape) {
|
||||||
if(ordering() == 'c' && order == 'f')
|
|
||||||
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
|
|
||||||
|
|
||||||
shape::setEws(shapeInfoNew, arrLength);
|
|
||||||
setShapeInfo(shapeInfoNew);
|
setShapeInfo(shapeInfoNew);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
NDArray temp(order, shape, dataType(), getContext());
|
NDArray temp(order, shape, dataType(), getContext());
|
||||||
|
if(copyToNewBuff)
|
||||||
this->applyTransform(transform::Assign, temp, nullptr);
|
this->applyTransform(transform::Assign, temp, nullptr);
|
||||||
*this = std::move(temp);
|
*this = std::move(temp);
|
||||||
}
|
}
|
||||||
|
@@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
||||||
auto shapeOf = shape::shapeOf(newShapeInfo);
|
auto shapeOf = shape::shapeOf(newShapeInfo);
|
||||||
auto stridesOf = shape::stride(newShapeInfo);
|
auto stridesOf = shape::stride(newShapeInfo);
|
||||||
|
|
||||||
Nd4jLong offset(0), subArrLen(1);
|
Nd4jLong offset = 0;
|
||||||
int n(isStrided ? 3 : 2), first, last, stride;
|
int n(isStrided ? 3 : 2), first, last, stride;
|
||||||
|
|
||||||
for (int d = rank - 1; d >= 0; --d) {
|
for (int d = rank - 1; d >= 0; --d) {
|
||||||
|
@@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
||||||
if(shapeOf[d] != 1)
|
if(shapeOf[d] != 1)
|
||||||
stridesOf[d] *= stride;
|
stridesOf[d] *= stride;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
subArrLen *= shapeOf[d];
|
Nd4jLong *shapeInfoNoUnities = newShapeInfo;
|
||||||
|
|
||||||
|
if(!keepUnitiesInShape) {
|
||||||
|
|
||||||
|
std::vector<int> dimsWithUnities;
|
||||||
|
|
||||||
|
for (uint d = 0; d < rank; ++d)
|
||||||
|
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
|
||||||
|
dimsWithUnities.push_back(d);
|
||||||
|
|
||||||
|
if(!dimsWithUnities.empty())
|
||||||
|
shapeInfoNoUnities = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if there is possibility to set ews = 1
|
// check if there is possibility to set ews = 1
|
||||||
shape::setEws(newShapeInfo, subArrLen);
|
shape::checkStridesSetEwsAndOrder(shapeInfoNoUnities);
|
||||||
|
|
||||||
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
|
NDArray result(_buffer, ShapeDescriptor(shapeInfoNoUnities), getContext(), offset + getBufferOffset());
|
||||||
result._isView = true;
|
result._isView = true;
|
||||||
|
|
||||||
if(!keepUnitiesInShape) {
|
|
||||||
const int coeff = isStrided ? 3 : 2;
|
|
||||||
std::vector<Nd4jLong> nonUnitDims;
|
|
||||||
|
|
||||||
for (int d = 0; d < rank; ++d)
|
|
||||||
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
|
|
||||||
nonUnitDims.push_back(newShapeInfo[d+1]);
|
|
||||||
|
|
||||||
if(nonUnitDims.size() != rank)
|
|
||||||
result.reshapei(nonUnitDims);
|
|
||||||
}
|
|
||||||
|
|
||||||
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
||||||
|
if(newShapeInfo != shapeInfoNoUnities)
|
||||||
|
RELEASE(shapeInfoNoUnities, getContext()->getWorkspace());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@@ -51,6 +51,13 @@ namespace nd4j {
|
||||||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
|
||||||
|
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
|
||||||
|
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
||||||
|
*/
|
||||||
|
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
|
||||||
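The comment above describes this transformation on the packed shapeInfo buffer; the following standalone sketch (plain std::vector, not the library's packed layout or workspace allocation) shows the same idea of dropping unit dimensions together with their strides:

#include <cstdint>
#include <vector>

// shape {2,1,3,1,4}, strides {12,12,4,4,1}  ->  shape {2,3,4}, strides {12,4,1}
static void dropUnitDims(const std::vector<int64_t>& shape,
                         const std::vector<int64_t>& strides,
                         std::vector<int64_t>& outShape,
                         std::vector<int64_t>& outStrides) {
    for (std::size_t d = 0; d < shape.size(); ++d) {
        if (shape[d] == 1)
            continue;                       // skip the unity and its stride
        outShape.push_back(shape[d]);
        outStrides.push_back(strides[d]);
    }
}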
|
|
||||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||||
|
|
||||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
||||||
|
|
|
@@ -68,7 +68,7 @@ namespace nd4j {
|
||||||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
||||||
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
||||||
|
|
||||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all of them)
|
||||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||||
|
|
||||||
if (numOfSubArrs > 0)
|
if (numOfSubArrs > 0)
|
||||||
|
|
|
@@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
|
||||||
|
|
||||||
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
|
|
||||||
NDArray aPR = a->permute(permutAt);
|
// check whether permutation is necessary
|
||||||
NDArray bPR = b->permute(permutBt);
|
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||||
|
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||||
|
|
||||||
// check whether reshape is necessary
|
// check whether reshape is necessary
|
||||||
if(!aPR.isSameShape(shapeAt))
|
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||||
aPR.reshapei( shapeAt);
|
const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||||
if(!bPR.isSameShape(shapeBt))
|
|
||||||
bPR.reshapei( shapeBt);
|
|
||||||
|
|
||||||
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
|
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
|
||||||
|
|
||||||
c->reshapei(outShape);
|
c->reshapei(outShape);
|
||||||
|
|
||||||
|
if(aP != aPR)
|
||||||
|
delete aPR;
|
||||||
|
if(bP != bPR)
|
||||||
|
delete bPR;
|
||||||
|
if(a != aP)
|
||||||
|
delete aP;
|
||||||
|
if(b != bP)
|
||||||
|
delete bP;
|
||||||
|
|
||||||
return c;
|
return c;
|
||||||
}
|
}
|
||||||
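For orientation, a hedged usage sketch of this overload: contracting a single pair of axes degenerates to an ordinary matrix product. NDArrayFactory::create and the header names are assumptions taken from elsewhere in libnd4j, not part of this change:

#include <NDArrayFactory.h>
#include <MmulHelper.h>      // header names assumed; adjust to the local include layout

void tensorDotSketch() {
    auto A = nd4j::NDArrayFactory::create<float>('c', {2, 3});
    auto B = nd4j::NDArrayFactory::create<float>('f', {3, 4});

    // contract axis 1 of A with axis 0 of B -> C has shape {2, 4}
    nd4j::NDArray* C = nd4j::MmulHelper::tensorDot(&A, &B, {1}, {0});

    delete C;   // this overload allocates the result, the caller owns it
}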
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
|
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
|
||||||
|
|
||||||
|
@@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
|
||||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||||
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
|
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
|
||||||
|
|
||||||
NDArray *cP(c), *cPR(c);
|
|
||||||
|
|
||||||
// check whether permutation is required
|
// check whether permutation is required
|
||||||
if(!permutForC.empty())
|
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
|
||||||
cP = new NDArray(c->permute(permutForC));
|
|
||||||
|
|
||||||
auto aPR = a->permute(permutAt);
|
// check whether permutation is necessary
|
||||||
auto bPR = b->permute(permutBt);
|
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||||
|
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||||
|
|
||||||
// check whether reshape is necessary
|
// check whether reshape is necessary
|
||||||
if(!aPR.isSameShape(shapeAt))
|
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||||
aPR.reshapei(shapeAt);
|
const NDArray* bPR = bP->isSameShape(shapeBt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||||
if(!bPR.isSameShape(shapeBt))
|
|
||||||
bPR.reshapei(shapeBt);
|
|
||||||
|
|
||||||
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
|
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
|
||||||
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
|
|
||||||
|
|
||||||
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
|
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
|
||||||
|
|
||||||
|
mmul(aPR, bPR, cPR, 1.0, 0.0);
|
||||||
|
|
||||||
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
||||||
cP->assign(cPR);
|
cP->assign(cPR);
|
||||||
|
|
||||||
if(cPR != c)
|
if(aP != aPR)
|
||||||
|
delete aPR;
|
||||||
|
if(bP != bPR)
|
||||||
|
delete bPR;
|
||||||
|
if(a != aP)
|
||||||
|
delete aP;
|
||||||
|
if(b != bP)
|
||||||
|
delete bP;
|
||||||
|
|
||||||
|
if(cP != cPR)
|
||||||
delete cPR;
|
delete cPR;
|
||||||
if(cP != c)
|
if(c != cP)
|
||||||
delete cP;
|
delete cP;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
|
||||||
if(!whatToDoWithC.empty()) {
|
if(!whatToDoWithC.empty()) {
|
||||||
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
||||||
for(int i = 0; i < cArrs.size()-1; ++i)
|
for(int i = 0; i < cArrs.size()-1; ++i)
|
||||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||||
}
|
}
|
||||||
|
|
||||||
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
||||||
|
@@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
|
||||||
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
|
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
|
||||||
if(isAVector && bRank == 2) {
|
if(isAVector && bRank == 2) {
|
||||||
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
|
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
|
||||||
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
|
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
|
||||||
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
|
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
|
||||||
delete A2;
|
delete A2;
|
||||||
delete C2;
|
delete C2;
|
||||||
|
|
|
@@ -139,5 +139,15 @@ namespace nd4j {
|
||||||
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
|
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
|
||||||
|
|
||||||
|
Nd4jLong *outShapeInfo = nullptr;
|
||||||
|
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
|
||||||
|
|
||||||
|
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
|
||||||
|
|
||||||
|
return outShapeInfo;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
|
@@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
||||||
permutBt = axesB;
|
permutBt = axesB;
|
||||||
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
||||||
|
|
||||||
|
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
|
||||||
|
uint i1, i2;
|
||||||
|
for(i1 = 0; i1 < aRank; ++i1)
|
||||||
|
if(permutAt[i1] != i1)
|
||||||
|
break;
|
||||||
|
if(i1 == aRank)
|
||||||
|
permutAt = {};
|
||||||
|
for(i2 = 0; i2 < bRank; ++i2)
|
||||||
|
if(permutBt[i2] != i2)
|
||||||
|
break;
|
||||||
|
if(i2 == bRank)
|
||||||
|
permutBt = {};
|
||||||
|
|
||||||
Nd4jLong n2 = 1;
|
Nd4jLong n2 = 1;
|
||||||
for (int i = 0; i < axeAsize; i++)
|
for (int i = 0; i < axeAsize; i++)
|
||||||
n2 *= aShapeInfo[axesA[i] + 1];
|
n2 *= aShapeInfo[axesA[i] + 1];
|
||||||
shapeAt = {-1, n2};
|
shapeAt = {shape::length(aShapeInfo) / n2, n2};
|
||||||
|
|
||||||
std::vector<Nd4jLong> oldShapeA;
|
std::vector<Nd4jLong> oldShapeA;
|
||||||
oldShapeA.resize(list_A.size());
|
oldShapeA.resize(list_A.size());
|
||||||
|
@@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
||||||
Nd4jLong n3 = 1;
|
Nd4jLong n3 = 1;
|
||||||
for (int i = 0; i < axeBsize; i++)
|
for (int i = 0; i < axeBsize; i++)
|
||||||
n3 *= bShapeInfo[axesB[i] + 1];
|
n3 *= bShapeInfo[axesB[i] + 1];
|
||||||
shapeBt = {n3, -1};
|
shapeBt = {n3, shape::length(bShapeInfo) / n3};
|
||||||
|
|
||||||
std::vector<Nd4jLong> oldShapeB;
|
std::vector<Nd4jLong> oldShapeB;
|
||||||
oldShapeB.resize(list_B.size());
|
oldShapeB.resize(list_B.size());
|
||||||
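The shapeAt/shapeBt values computed above reduce the tensor contraction to a single matrix product; a standalone arithmetic sketch (plain C++, no libnd4j types) of the shapes involved:

#include <cstdint>
#include <vector>

// tensordot(A, B, axesA, axesB): n2 = product of the contracted extents of A,
// A is permuted/reshaped to {lenA / n2, n2}, B to {n2, lenB / n2}, so the
// matrix product has shape {lenA / n2, lenB / n2}. The extents of B along
// axesB must equal n2 for the contraction to be valid.
// Example: A {2,3,4} over axis 1, B {3,5} over axis 0:
//   n2 = 3, A -> {8,3}, B -> {3,5}, C -> {8,5}, finally reshaped to {2,4,5}.
static std::vector<int64_t> matmulResultShape(const std::vector<int64_t>& aShape,
                                              const std::vector<int64_t>& bShape,
                                              const std::vector<int>& axesA) {
    int64_t lenA = 1, lenB = 1, n2 = 1;
    for (auto d : aShape) lenA *= d;
    for (auto d : bShape) lenB *= d;
    for (auto a : axesA)  n2   *= aShape[a];
    return {lenA / n2, lenB / n2};
}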
|
@@ -306,10 +319,10 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
||||||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||||
|
|
||||||
if (!arr.nonNull())
|
if (!arr.nonNull())
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||||
|
|
||||||
if (rank != arr.rankOf())
|
if (rank != arr.rankOf())
|
||||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
|
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||||
|
|
||||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||||
// allocate memory for new array - shapeInfo
|
// allocate memory for new array - shapeInfo
|
||||||
|
|
|
@@ -131,7 +131,11 @@ namespace shape {
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder);
|
ND4J_EXPORT _CUDA_HD bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShape, bool isFOrder);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo);
|
ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo);
|
||||||
|
/**
|
||||||
|
* newShapeInfo contains rank, shape and order only, no strides/ews/type
|
||||||
|
*/
|
||||||
|
ND4J_EXPORT _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the shape info buffer
|
* Get the shape info buffer
|
||||||
|
@@ -365,6 +369,13 @@ namespace shape {
|
||||||
ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo);
|
ND4J_EXPORT _CUDA_HD bool isRowVector(const Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo);
|
ND4J_EXPORT _CUDA_HD bool isColumnVector(Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* shape - input inShape is shape only, not shapeInfo
|
||||||
|
* returns number of non-unity dimensions in inShape
|
||||||
|
*/
|
||||||
|
ND4J_EXPORT _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns whether the
|
* Returns whether the
|
||||||
* given shape is a vector or not
|
* given shape is a vector or not
|
||||||
|
@@ -379,7 +390,8 @@ namespace shape {
|
||||||
* Returns the shape portion of an information
|
* Returns the shape portion of an information
|
||||||
* buffer
|
* buffer
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *buffer);
|
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo);
|
||||||
|
ND4J_EXPORT _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a copy of a buffer.
|
* Return a copy of a buffer.
|
||||||
|
@@ -994,21 +1006,16 @@ namespace shape {
|
||||||
// rank is equal to size of shape
|
// rank is equal to size of shape
|
||||||
ND4J_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c');
|
ND4J_EXPORT void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong* strides, Nd4jLong* offsets, const char order = 'c');
|
||||||
ND4J_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c');
|
ND4J_EXPORT void calcOffsets(const Nd4jLong* shapeInfo, Nd4jLong* offsets, const char order = 'c');
|
||||||
ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c');
|
// ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order = 'c');
|
||||||
ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c');
|
// ND4J_EXPORT void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order = 'c');
|
||||||
ND4J_EXPORT _CUDA_HD void shapeOldScalar(nd4j::DataType dtype, Nd4jLong* const buffer, const char order);
|
ND4J_EXPORT _CUDA_HD void shapeOldScalar(nd4j::DataType dtype, Nd4jLong* const buffer, const char order);
|
||||||
|
|
||||||
// deduce element-wise stride
|
|
||||||
// if array is scalar or unit length vector then ews = 1
|
|
||||||
// if array is common vector then ews = stride of non-unity dimension
|
|
||||||
// if strides are normal set ews = 1, otherwise ews = 0
|
|
||||||
ND4J_EXPORT _CUDA_HD void setEws(Nd4jLong* shapeInfo, Nd4jLong len);
|
|
||||||
|
|
||||||
// deduce order and element-wise stride
|
// deduce order and element-wise stride
|
||||||
// if array is scalar or unit length vector then ews = 1 and order is preserved
|
// if array is scalar or unit length vector then ews = 1 and order is preserved
|
||||||
// if array is common vector then ews = stride of non-unity dimension and order is preserved
|
// if array is common vector then ews = stride of non-unity dimension and order is preserved
|
||||||
// if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved
|
// if strides are normal/contiguous then ews = 1 and corresponding order is set, otherwise ews = 0 and order is preserved
|
||||||
ND4J_EXPORT _CUDA_HD void setOrderAndEws(Nd4jLong* shapeInfo, Nd4jLong len = -1);
|
ND4J_EXPORT _CUDA_HD void checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnitDims, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities);
|
||||||
|
ND4J_EXPORT _CUDA_HD void checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo);
|
||||||
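A hedged, standalone restatement of the ews rules listed above, on plain vectors instead of a packed shapeInfo (the real routine additionally rewrites the order field):

#include <cstdint>
#include <vector>

// ews = 1 for scalars/length-1 arrays, the single non-unit stride for a common
// vector, 1 when the strides are contiguous for the given order, otherwise 0.
static int64_t deduceEws(const std::vector<int64_t>& shape,
                         const std::vector<int64_t>& strides,
                         char order) {
    int64_t len = 1;
    for (auto d : shape) len *= d;
    if (len <= 1)
        return 1;                                     // scalar or unit-length vector

    int nonUnit = -1, numNonUnit = 0;                 // common-vector case
    for (std::size_t i = 0; i < shape.size(); ++i)
        if (shape[i] != 1) { nonUnit = (int)i; ++numNonUnit; }
    if (numNonUnit == 1)
        return strides[nonUnit];

    int64_t expected = 1;                             // contiguity check
    if (order == 'c') {
        for (int i = (int)shape.size() - 1; i >= 0; --i) {
            if (shape[i] != 1 && strides[i] != expected) return 0;
            expected *= shape[i];
        }
    } else {                                          // 'f'
        for (std::size_t i = 0; i < shape.size(); ++i) {
            if (shape[i] != 1 && strides[i] != expected) return 0;
            expected *= shape[i];
        }
    }
    return 1;
}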
|
|
||||||
/**
|
/**
|
||||||
* processes whole set of sub-arrays
|
* processes whole set of sub-arrays
|
||||||
|
@@ -1018,12 +1025,26 @@ namespace shape {
|
||||||
* numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs
|
* numOfSubArrs - number of sub-arrays, size of subArrOffsets is equal to numOfSubArrs
|
||||||
* dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned
|
* dimsSize - size of dimsToExclude, if dimsSize = array rank or dimsSize = 0 it means sub-array is whole array, copy of wholeShapeInfo and one zero offset will be returned
|
||||||
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5]
|
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-array along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5]
|
||||||
* subArrShapeInfo - output argument, contains shapeInfo common for all sub-arrays
|
* subArrShapeInfo - output argument, contains shapeInfo (same for all sub-arrays)
|
||||||
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
|
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
|
||||||
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
|
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
|
||||||
|
* then output shapeNoUnities will contain {2,4, 4,1} - that is only shape and strides, no rank/type/ews/order
|
||||||
|
* stridesNoUnities will point on strides in shapeNoUnities that is on {4,1}
|
||||||
|
* returns number of non-unity dimensions in inShapeInfo
|
||||||
|
* if there is no unities in inShapeInfo, then no copy procedure will be performed and shapeNoUnities/stridesNoUnities will point on corresponding places in inShapeInfo
|
||||||
|
*/
|
||||||
|
ND4J_EXPORT _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
|
||||||
|
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
||||||
|
*/
|
||||||
|
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2050,7 +2071,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank];
|
shapeInfo[i + 1 + rank] = temp[rearrange[i] + 1 + rank];
|
||||||
}
|
}
|
||||||
|
|
||||||
shape::setOrderAndEws(shapeInfo, len);
|
shape::checkStridesSetEwsAndOrder(shapeInfo);
|
||||||
|
|
||||||
delete[] temp;
|
delete[] temp;
|
||||||
}
|
}
|
||||||
|
@@ -2227,7 +2248,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim) {
|
INLINEDEF _CUDA_HD bool isCommonVector(const Nd4jLong *shapeInfo, int& posOfNonUnityDim) {
|
||||||
|
|
||||||
if(rank(shapeInfo) > 0 && length(shapeInfo) == 1) {
|
if(rank(shapeInfo) > 0 && length(shapeInfo) == 1) {
|
||||||
posOfNonUnityDim = 0;
|
posOfNonUnityDim = -1;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2272,6 +2293,18 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
return isVector && !shapeFirstOne;
|
return isVector && !shapeFirstOne;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD int numOfNonUnitDims(const int rank, const Nd4jLong* inShape) {
|
||||||
|
|
||||||
|
int num = 0;
|
||||||
|
|
||||||
|
for(uint i = 0; i < rank; ++i)
|
||||||
|
if(inShape[i] != 1)
|
||||||
|
++num;
|
||||||
|
|
||||||
|
return num;
|
||||||
|
}
|
||||||
|
|
||||||
INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) {
|
INLINEDEF _CUDA_HD int oneDimEqualToLength(Nd4jLong *shape, int rank) {
|
||||||
for(int i = 0; i < rank; i++) {
|
for(int i = 0; i < rank; i++) {
|
||||||
if(shape[i] == shape::prodLong(shape,rank))
|
if(shape[i] == shape::prodLong(shape,rank))
|
||||||
|
@@ -2310,8 +2343,14 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
* Returns the shape portion of an information
|
* Returns the shape portion of an information
|
||||||
* buffer
|
* buffer
|
||||||
*/
|
*/
|
||||||
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *buffer) {
|
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(Nd4jLong *shapeInfo) {
|
||||||
return buffer + 1;
|
|
||||||
|
return shapeInfo + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
INLINEDEF _CUDA_HD Nd4jLong *shapeOf(const Nd4jLong *shapeInfo) {
|
||||||
|
|
||||||
|
return shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -2444,7 +2483,7 @@ INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeIn
|
||||||
newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer);
|
newShapeBuffer[2 * newRank + 3] = shape::order(shapeBuffer);
|
||||||
|
|
||||||
// correct order and ews if necessary
|
// correct order and ews if necessary
|
||||||
shape::setOrderAndEws(newShapeBuffer);
|
shape::checkStridesSetEwsAndOrder(newShapeBuffer);
|
||||||
|
|
||||||
delete[] indices;
|
delete[] indices;
|
||||||
|
|
||||||
|
@@ -3918,62 +3957,51 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) {
|
||||||
// return true;
|
// return true;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, const bool isFOrder, Nd4jLong* newShapeInfo) {
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
// INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
|
||||||
|
|
||||||
// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements
|
// // PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements
|
||||||
// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo
|
// // also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo
|
||||||
|
|
||||||
// const int newOrder = isFOrder ? 102 : 99;
|
|
||||||
// const int oldOrder = oldShapeInfo[2 * oldRank + 3];
|
|
||||||
|
|
||||||
// newShapeInfo[0] = newRank;
|
// newShapeInfo[0] = newRank;
|
||||||
// memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
|
// memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
|
||||||
|
|
||||||
// Nd4jLong* newStrides = shape::stride(newShapeInfo);
|
// Nd4jLong* newStrides = shape::stride(newShapeInfo);
|
||||||
// const Nd4jLong* oldShape = shape::shapeOf(const_cast<Nd4jLong*>(oldShapeInfo));
|
// const Nd4jLong* oldShape = shape::shapeOf(const_cast<Nd4jLong*>(oldShapeInfo));
|
||||||
// const Nd4jLong* oldStrides = shape::stride(const_cast<Nd4jLong*>(oldShapeInfo));
|
// const Nd4jLong* oldStrides = shape::stride(const_cast<Nd4jLong*>(oldShapeInfo));
|
||||||
// int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
|
// Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
|
||||||
|
|
||||||
|
|
||||||
// while (newStart < newRank && oldStart < oldRank) {
|
// while (newStart < newRank && oldStart < oldRank) {
|
||||||
|
|
||||||
// newDim = newShape[newStart];
|
// newDim = newShape[newStart];
|
||||||
// oldDim = oldShape[oldStart];
|
// oldDim = oldShape[oldStart];
|
||||||
|
|
||||||
// while (newDim != oldDim)
|
// while (newDim != oldDim && newDim > 0 && oldDim > 0)
|
||||||
// if (newDim < oldDim) newDim *= newShape[newStop++];
|
// if (newDim < oldDim) newDim *= newShape[newStop++];
|
||||||
// else oldDim *= oldShape[oldStop++];
|
// else oldDim *= oldShape[oldStop++];
|
||||||
|
|
||||||
// // ------ Check whether the original axes can be combined ------ //
|
// // ------ Check whether the original axes can be combined ------ //
|
||||||
// for (int i = oldStart; i < oldStop - 1; i++) {
|
// for (int step = 1, i = oldStart; i < oldStop - 1; ++i) {
|
||||||
|
// if(oldShape[i] == 1) // skip unity-dimension and its stride
|
||||||
// if(oldShape[i] == 1) { // ignore strides like {...,1,1,...}
|
|
||||||
// if(oldOrder == 102) ++oldStart;
|
|
||||||
// continue;
|
// continue;
|
||||||
// }
|
// while((i + step) < oldRank && oldShape[i + step] == 1)
|
||||||
|
// ++step; // skip following unity-dimensions and its strides if such are present
|
||||||
// if(oldOrder == 102 && oldStrides[i + 1] != oldShape[i] * oldStrides[i])
|
// if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step])
|
||||||
// return false; // not contiguous enough
|
|
||||||
// if(oldOrder == 99 && oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1])
|
|
||||||
// return false; // not contiguous enough
|
// return false; // not contiguous enough
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// // ------ Calculate new strides for all axes currently worked with ------ //
|
|
||||||
// if(isFOrder) {
|
|
||||||
// newStrides[newStart] = oldStrides[oldStart];
|
|
||||||
// for (int i = newStart + 1; i < newStop; ++i)
|
|
||||||
// newStrides[i] = newStrides[i - 1] * newShape[i - 1];
|
|
||||||
// }
|
|
||||||
// else {
|
|
||||||
// newStrides[newStop - 1] = oldStrides[oldStop - 1];
|
// newStrides[newStop - 1] = oldStrides[oldStop - 1];
|
||||||
// for (int i = newStop - 1; i > newStart; --i)
|
// for (int i = newStop - 1; i > newStart; --i)
|
||||||
// newStrides[i - 1] = newStrides[i] * newShape[i];
|
// newStrides[i - 1] = newStrides[i] * newShape[i];
|
||||||
// }
|
|
||||||
|
|
||||||
// newStart = newStop++;
|
// newStart = newStop++;
|
||||||
// oldStart = oldStop++;
|
// oldStart = oldStop++;
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
// // rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
|
||||||
|
// for (int i = newStart; i < newRank; ++i)
|
||||||
|
// newStrides[i] = 1;
|
||||||
|
|
||||||
// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order
|
// newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order
|
||||||
// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews
|
// newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews
|
||||||
// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
|
// newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
|
||||||
|
@ -3982,57 +4010,98 @@ INLINEDEF _CUDA_HD bool areStridesDefault(const Nd4jLong* shapeInfo) {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_H bool reshapeC(const int oldRank, const Nd4jLong* oldShapeInfo, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
|
INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, const char newOrder, const int newRank, const Nd4jLong* newShape, Nd4jLong* newShapeInfo) {
|
||||||
|
|
||||||
// PLEASE NOTE !: reshaping not-permuted (ews=1) array in f order (except insertion/elimination of unities) will definitely cause allocation of new buffer for array elements
|
|
||||||
// also this function takes into account identical shapes automatically, namely in that case oldShapeInfo is completely copied to newShapeInfo
|
|
||||||
|
|
||||||
|
// copy shape from newShape into newShapeInfo
|
||||||
newShapeInfo[0] = newRank;
|
newShapeInfo[0] = newRank;
|
||||||
memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
|
memcpy(newShapeInfo + 1, newShape, newRank * sizeof(Nd4jLong));
|
||||||
|
|
||||||
Nd4jLong* newStrides = shape::stride(newShapeInfo);
|
// copy order
|
||||||
const Nd4jLong* oldShape = shape::shapeOf(const_cast<Nd4jLong*>(oldShapeInfo));
|
newShapeInfo[2 * newRank + 3] = newOrder;
|
||||||
const Nd4jLong* oldStrides = shape::stride(const_cast<Nd4jLong*>(oldShapeInfo));
|
|
||||||
Nd4jLong oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
|
|
||||||
|
|
||||||
while (newStart < newRank && oldStart < oldRank) {
|
return shape::reshapeC(oldShapeInfo, newShapeInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShapeInfo) {
|
||||||
|
|
||||||
|
// newShapeInfo contains rank, shape and order; but no strides, type and ews
|
||||||
|
|
||||||
|
const int newRank = shape::rank(newShapeInfo);
|
||||||
|
|
||||||
|
// if oldShapeInfo is scalar or vector with length=1
|
||||||
|
if(shape::length(oldShapeInfo) == 1) {
|
||||||
|
for (uint i = 0; i < newRank; ++i)
|
||||||
|
shape::stride(newShapeInfo)[i] = 1;
|
||||||
|
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo);
|
||||||
|
*shape::ews(newShapeInfo) = 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto oldOrder = shape::order(oldShapeInfo);
|
||||||
|
const auto newOrder = shape::order(newShapeInfo);
|
||||||
|
const auto oldEws = shape::elementWiseStride(const_cast<Nd4jLong*>(oldShapeInfo));
|
||||||
|
|
||||||
|
if(oldEws > 0 && oldOrder != newOrder)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// *** FIRST STAGE - exclude unity dimensions from oldShapeInfo and newShapeInfo (if such are present of course), since they don't affect on strides evaluation, however they complicate code
|
||||||
|
|
||||||
|
// FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
|
||||||
|
Nd4jLong tempBuffer[4*MAX_RANK];
|
||||||
|
Nd4jLong *oldShape = tempBuffer, *newShape = tempBuffer + 2*MAX_RANK, *oldStrides, *newStrides;
|
||||||
|
|
||||||
|
// exclude unities from oldShapeInfo
|
||||||
|
const int oldNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(oldShapeInfo, oldShape, oldStrides);
|
||||||
|
const int newNumOfNonUnities = shape::excludeUnitiesFromShapeInfo(newShapeInfo, newShape, newStrides);
|
||||||
|
|
||||||
|
// *** SECOND STAGE - strides evaluation
|
||||||
|
|
||||||
|
int oldStart(0), oldStop(1), newStart(0), newStop(1), newDim, oldDim;
|
||||||
|
|
||||||
|
while (newStart < newNumOfNonUnities && oldStart < oldNumOfNonUnities) {
|
||||||
|
|
||||||
newDim = newShape[newStart];
|
newDim = newShape[newStart];
|
||||||
oldDim = oldShape[oldStart];
|
oldDim = oldShape[oldStart];
|
||||||
|
|
||||||
while (newDim != oldDim && newDim > 0 && oldDim > 0)
|
while (newDim != oldDim && newDim > 0 && oldDim > 0) {
|
||||||
if (newDim < oldDim) newDim *= newShape[newStop++];
|
|
||||||
else oldDim *= oldShape[oldStop++];
|
|
||||||
|
|
||||||
// ------ Check whether the original axes can be combined ------ //
|
if (newDim < oldDim)
|
||||||
for (int step = 1, i = oldStart; i < oldStop - 1; ++i) {
|
newDim *= newShape[newStop++];
|
||||||
if(oldShape[i] == 1) // skip unity-dimension and its stride
|
else
|
||||||
continue;
|
oldDim *= oldShape[oldStop++];
|
||||||
while((i + step) < oldRank && oldShape[i + step] == 1)
|
|
||||||
++step; // skip following unity-dimensions and its strides if such are present
|
|
||||||
if((i + step) < oldRank && oldStrides[i] != oldShape[i + step] * oldStrides[i + step])
|
|
||||||
return false; // not contiguous enough
|
|
||||||
}
|
}
|
||||||
|
|
||||||
newStrides[newStop - 1] = oldStrides[oldStop - 1];
|
// check c-contiguous of old axes range
|
||||||
for (int i = newStop - 1; i > newStart; --i)
|
for(uint i = oldStart; i < oldStop - 1; ++i) // do not check value of last stride, it doesn't matter
|
||||||
newStrides[i - 1] = newStrides[i] * newShape[i];
|
if(oldStrides[i] != oldShape[i + 1] * oldStrides[i + 1])
|
||||||
|
return false; // not contiguous
|
||||||
|
|
||||||
|
// fill newStrides in c manner
|
||||||
|
newStrides[newStop - 1] = oldStrides[oldStop - 1]; // copy last stride
|
||||||
|
for (int i = newStop - 2; i >= newStart; --i)
|
||||||
|
newStrides[i] = newStrides[i + 1] * newShape[i + 1];
|
||||||
|
|
||||||
newStart = newStop++;
|
newStart = newStop++;
|
||||||
oldStart = oldStop++;
|
oldStart = oldStop++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// rest of strides should be unities (if there is remainder in strides space, that is newStart < newRank)
|
// fill new calculated strides into newShapeInfo, take into account possible unities in shape
|
||||||
for (int i = newStart; i < newRank; ++i)
|
for (int j = 0, i = 0; i < newRank; ++i)
|
||||||
newStrides[i] = 1;
|
shape::stride(newShapeInfo)[i] = (shape::shapeOf(newShapeInfo)[i] == 1) ? 1 : newStrides[j++];
|
||||||
|
|
||||||
|
// set ews
|
||||||
|
if(oldEws == 0)
|
||||||
|
shape::checkStridesSetEwsAndOrder(newShapeInfo, newOrder, newNumOfNonUnities, newShape, newStrides); // set ews and order
|
||||||
|
else {
|
||||||
|
newShapeInfo[2 * newRank + 3] = oldOrder; // order
|
||||||
|
*shape::ews(newShapeInfo) = oldEws; // ews
|
||||||
|
}
|
||||||
|
|
||||||
newShapeInfo[2 * newRank + 3] = shape::order(oldShapeInfo); // order
|
|
||||||
newShapeInfo[2 * newRank + 2] = shape::elementWiseStride(oldShapeInfo); // ews
|
|
||||||
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
|
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
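As a concrete illustration of the contiguity test inside the loop above: an array of shape {3,4} with c-order strides {4,1} can be viewed as {2,6} or {12} without copying, but its transpose (shape {4,3}, strides {1,4}) cannot, because 1 != 3*4 and reshapeC returns false. A standalone sketch of that check, under the simplifying assumption that all old axes fall into a single group:

#include <cstdint>
#include <vector>

// true when the given axes can be collapsed/re-split without moving data,
// i.e. each stride equals the next dimension times the next stride (c order)
static bool axesAreCContiguous(const std::vector<int64_t>& shape,
                               const std::vector<int64_t>& strides) {
    for (std::size_t i = 0; i + 1 < shape.size(); ++i)
        if (strides[i] != shape[i + 1] * strides[i + 1])
            return false;
    return true;
}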
|
|
||||||
|
|
||||||
|
|
||||||
INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder) {
|
INLINEDEF _CUDA_H bool canReshape(const int oldRank, Nd4jLong* oldShape, const int newRank, Nd4jLong* newShapeOf, bool isFOrder) {
|
||||||
|
@@ -4573,129 +4642,75 @@ INLINEDEF void calcOffsets(const int rank, const Nd4jLong* shape, const Nd4jLong
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF void _CUDA_HD setEws(Nd4jLong* shapeInfo, Nd4jLong len) {
|
INLINEDEF void _CUDA_HD checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo) {
|
||||||
|
|
||||||
|
// FIXME - indeed we don't need to allocate so large memory amount (2*MAX_RANK), sufficient amount is (2*oldNumOfNonUnities + 2*newNumOfNonUnities)
|
||||||
|
Nd4jLong tempBuffer[2*MAX_RANK];
|
||||||
|
Nd4jLong *shape = tempBuffer, *strides;
|
||||||
|
|
||||||
const int rank = shape::rank(shapeInfo);
|
// exclude unities from shapeInfo
|
||||||
const Nd4jLong* shape = shape::shapeOf(shapeInfo);
|
const int numOfNonUnities = shape::excludeUnitiesFromShapeInfo(shapeInfo, shape, strides);
|
||||||
const Nd4jLong* strides = shape::stride(shapeInfo);
|
|
||||||
const char order = shape::order(shapeInfo);
|
|
||||||
Nd4jLong* ews = shape::ews(shapeInfo);
|
|
||||||
|
|
||||||
if(len == -1) // calculate array length if it is not given
|
shape::checkStridesSetEwsAndOrder(shapeInfo, shape::order(shapeInfo), numOfNonUnities, shape, strides);
|
||||||
len = shape::length(shapeInfo);
|
|
||||||
|
|
||||||
if(len <= 1) { // empty, scalar or unity-vector case
|
|
||||||
*ews = 1;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
int nonUnityDim(0);
|
|
||||||
if(shape::isCommonVector(shapeInfo, nonUnityDim)) {
|
|
||||||
*ews = strides[nonUnityDim];
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// check last(c)/first(f) dimension, it should be equal to 1
|
|
||||||
if((order == 'c' && shape[rank - 1] != 1 && strides[rank - 1] != 1) || (order == 'f' && shape[0] != 1 && strides[0] != 1)) {
|
|
||||||
*ews = 0;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Nd4jLong correctStride = 1;
|
|
||||||
if(order == 'c') {
|
|
||||||
for (int i = rank - 2; i >= 0 ; i--) {
|
|
||||||
correctStride *= shape[i + 1];
|
|
||||||
if(shape[i] == 1)
|
|
||||||
continue;
|
|
||||||
if(correctStride != strides[i]) {
|
|
||||||
*ews = 0;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
for (int i = 1; i < rank; ++i) {
|
|
||||||
correctStride *= shape[i - 1];
|
|
||||||
if(shape[i] == 1)
|
|
||||||
continue;
|
|
||||||
if(correctStride != strides[i]) {
|
|
||||||
*ews = 0;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
*ews = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
INLINEDEF _CUDA_HD void setOrderAndEws(Nd4jLong* shapeInfo, Nd4jLong len) {
|
INLINEDEF void _CUDA_HD checkStridesSetEwsAndOrder(Nd4jLong* shapeInfo, const char proposedOrder, const int numOfNonUnities, const Nd4jLong* shapeNoUnities, const Nd4jLong* stridesNoUnities) {
|
||||||
|
|
||||||
const int rank = shape::rank(shapeInfo);
|
const int rank = shape::rank(shapeInfo);
|
||||||
const Nd4jLong* shape = shape::shapeOf(shapeInfo);
|
|
||||||
const Nd4jLong* strides = shape::stride(shapeInfo);
|
|
||||||
const char order = shape::order(shapeInfo);
|
|
||||||
Nd4jLong* ews = shape::ews(shapeInfo);
|
|
||||||
|
|
||||||
if(len == -1) // calculate array length if it is not given
|
if(shape::length(shapeInfo) == 1) {
|
||||||
len = shape::length(shapeInfo);
|
*shape::ews(shapeInfo) = 1;
|
||||||
|
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
|
||||||
if(len <= 1) { // empty, scalar or unity-vector case
|
|
||||||
*ews = 1;
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int nonUnityDim(0);
|
if(numOfNonUnities == 1) { // case of common vector
|
||||||
if(shape::isCommonVector(shapeInfo, nonUnityDim)) { // in this case we don't change order
|
*shape::ews(shapeInfo) = *stridesNoUnities;
|
||||||
*ews = strides[nonUnityDim];
|
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if strides are contiguous in respect to c-order
|
bool contiguous = true;
|
||||||
// firstly check last stride, it should be equal to 1
|
|
||||||
if (strides[rank - 1] == 1 || shape[rank - 1] == 1) { // last dimension is ok, go on through the rest dimensions in reverse order
|
// *** check whether strides are in c contiguous order ***//
|
||||||
Nd4jLong correctStride = 1;
|
if(stridesNoUnities[numOfNonUnities - 1] != 1) // last stride should be always unity for c order
|
||||||
bool cContiguous = true;
|
contiguous = false;
|
||||||
for (int i = rank - 2; i >= 0 ; i--) {
|
else {
|
||||||
correctStride *= shape[i + 1];
|
for (uint i = 0; i < numOfNonUnities - 1; ++i) {
|
||||||
if(shape[i] == 1)
|
if(stridesNoUnities[i] != stridesNoUnities[i + 1] * shapeNoUnities[i + 1]) {
|
||||||
continue;
|
contiguous = false;
|
||||||
if(correctStride != strides[i]) {
|
|
||||||
cContiguous = false;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(cContiguous) {
|
}
|
||||||
*ews = 1;
|
if(contiguous) {
|
||||||
shapeInfo[shape::shapeInfoLength(rank) - 1] = 99;
|
*shape::ews(shapeInfo) = 1;
|
||||||
|
shapeInfo[rank * 2 + 3] = 99;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// now check if strides are contiguous in respect to f-order
|
contiguous = true;
|
||||||
// firstly check first stride, it should be equal to 1
|
|
||||||
if(strides[0] == 1 || shape[0] == 1) { // first dimension is ok, go on through the rest dimensions
|
//*** check whether strides are in f contiguous order ***//
|
||||||
Nd4jLong correctStride = 1;
|
if(stridesNoUnities[0] != 1) // first stride should be always unity for f order
|
||||||
bool fContiguous = true;
|
contiguous = false;
|
||||||
for (int i = 1; i < rank; ++i) {
|
else {
|
||||||
correctStride *= shape[i - 1];
|
for (uint i = 1; i < numOfNonUnities; ++i) {
|
||||||
if(shape[i] == 1)
|
if(stridesNoUnities[i] != stridesNoUnities[i - 1] * shapeNoUnities[i - 1]) {
|
||||||
continue;
|
contiguous = false;
|
||||||
if(correctStride != strides[i]) {
|
|
||||||
fContiguous = false;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(fContiguous) {
|
}
|
||||||
*ews = 1;
|
if(contiguous) {
|
||||||
shapeInfo[shape::shapeInfoLength(rank) - 1] = 102;
|
*shape::ews(shapeInfo) = 1;
|
||||||
|
shapeInfo[rank * 2 + 3] = 102;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
*ews = 0;
|
*shape::ews(shapeInfo) = 0;
|
||||||
// if both cContiguous and fContiguous are false then order is preserved
|
shapeInfo[rank * 2 + 3] = (int)proposedOrder;
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
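The contiguity rule used above can be checked in isolation: with unit dimensions already excluded, an array is c-contiguous when the last stride is 1 and stride[i] == stride[i+1] * shape[i+1] for every i, and f-contiguous under the mirrored condition. A small self-contained sketch (the example shapes and strides are assumptions):

#include <cstdio>
#include <vector>

// Returns 'c', 'f', or 0 for a shape/strides pair with unit dimensions already excluded.
static char contiguity(const std::vector<long long>& shape, const std::vector<long long>& strides) {
    const size_t n = shape.size();

    bool c = strides[n - 1] == 1;                       // last stride must be 1 for c order
    for (size_t i = 0; c && i + 1 < n; ++i)
        c = strides[i] == strides[i + 1] * shape[i + 1];

    bool f = strides[0] == 1;                           // first stride must be 1 for f order
    for (size_t i = 1; f && i < n; ++i)
        f = strides[i] == strides[i - 1] * shape[i - 1];

    return c ? 'c' : (f ? 'f' : 0);                     // 0 -> ews stays 0, order unchanged
}

int main() {
    printf("%c\n", contiguity({2, 3, 4}, {12, 4, 1}));  // c
    printf("%c\n", contiguity({2, 3, 4}, {1, 2, 6}));   // f
    printf("%d\n", contiguity({2, 3, 4}, {24, 8, 1}));  // 0 (strided view)
}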
@@ -4709,49 +4724,42 @@ INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo
         return;
     }

-    Nd4jLong *outShapeInfo = new Nd4jLong[shape::shapeInfoLength(wholeShapeInfo)];
-    memcpy(outShapeInfo, wholeShapeInfo, shape::shapeInfoByteLength(wholeShapeInfo));
+    const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize;
+
+    subArrShapeInfo[0] = subArrRank;                                        // rank
+    subArrShapeInfo[2 * subArrRank + 1] = shape::type(wholeShapeInfo);      // type
+    subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo);     // order

     Nd4jLong* shape   = new Nd4jLong[dimsSize];
     Nd4jLong* strides = new Nd4jLong[dimsSize];

-    const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize;
-
-    Nd4jLong* shapeNoUnities = nullptr;
-    if(!keepUnitiesInShape)
-        shapeNoUnities = new Nd4jLong[subArrRank];
-
-    Nd4jLong subArrLen = 1;
-
     for(int k = subArrRank - 1, j = dimsSize - 1, i = rank - 1; i >= 0; --i) {

         if(j >= 0 && i == dimsToExclude[j]) {
-            strides[j] = shape::stride(outShapeInfo)[i];
-            shape[j--] = shape::shapeOf(outShapeInfo)[i];
-            shape::shapeOf(outShapeInfo)[i] = 1;
+            strides[j] = shape::stride(wholeShapeInfo)[i];
+            shape[j--] = shape::shapeOf(wholeShapeInfo)[i];
+
+            if(keepUnitiesInShape) {
+                shape::shapeOf(subArrShapeInfo)[k]  = 1;
+                shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i];
+            }
         }
         else {
-            subArrLen *= shape::shapeOf(outShapeInfo)[i];
-            if(!keepUnitiesInShape)
-                shapeNoUnities[k--] = shape::shapeOf(outShapeInfo)[i];
+            shape::shapeOf(subArrShapeInfo)[k]  = shape::shapeOf(wholeShapeInfo)[i];
+            shape::stride(subArrShapeInfo)[k--] = shape::stride(wholeShapeInfo)[i];
         }
     }

-    // evaluate ews
-    shape::setEws(outShapeInfo, subArrLen);
-
     // calculation of sub-array offsets (subArrOffsets)
     shape::calcOffsets(dimsSize, shape, strides, subArrOffsets);

-    // remove unities from outShapeInfo if required
-    if(!keepUnitiesInShape) {
-        shape::reshapeC(rank, outShapeInfo, subArrRank, shapeNoUnities, subArrShapeInfo);
-        delete []shapeNoUnities;
-    }
-    else
-        memcpy(subArrShapeInfo, outShapeInfo, shape::shapeInfoLength(subArrRank) * sizeof(Nd4jLong));
+    // evaluate ews
+    shape::checkStridesSetEwsAndOrder(subArrShapeInfo);

     delete []strides;
     delete []shape;
-    delete []outShapeInfo;
 }

 //////////////////////////////////////////////////////////////////////
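The offsets written into subArrOffsets are simply all index combinations over the excluded dimensions, weighted by those dimensions' strides. A self-contained illustration of the idea (the example shape, strides and excluded dimension are assumptions, not taken from a real call site, and this does not use the shape.h helper itself):

#include <cstdio>
#include <vector>

// Offsets of sub-arrays obtained by fixing the excluded dimension of a c-ordered {2,3,4} array.
int main() {
    std::vector<long long> shape   = {2, 3, 4};
    std::vector<long long> strides = {12, 4, 1};
    std::vector<int> dimsToExclude = {0};               // one sub-array per index along dim 0

    std::vector<long long> offsets(1, 0);
    for (int d : dimsToExclude) {
        std::vector<long long> next;
        for (long long base : offsets)
            for (long long i = 0; i < shape[d]; ++i)
                next.push_back(base + i * strides[d]);  // every index along the excluded dim
        offsets.swap(next);
    }

    for (long long off : offsets)
        printf("%lld ", off);                           // prints: 0 12
    printf("\n");
}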
@@ -4815,197 +4823,240 @@ INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo,
 }

 //////////////////////////////////////////////////////////////////////
In this revision the three-input overload of calcOffsets is not deleted but disabled: every line of its body is commented out verbatim, and a commented-out two-input variant is added beside it. That added two-input variant mirrors, line for line, the live two-input calcOffsets overload removed further below, so it is shown only once. The disabled three-input body, as it now stands in the file:

// INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const Nd4jLong* zShapeInfo, Nd4jLong*& zOffsets, const char order) {
//
//     // we assume all array have same length
//     const Nd4jLong len = shape::length(xShapeInfo);
//
//     const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
//     const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
//     const Nd4jLong zEws = shape::elementWiseStride(zShapeInfo);
//
//     const char xOrder = shape::order(xShapeInfo);
//     const char yOrder = shape::order(yShapeInfo);
//     const char zOrder = shape::order(zShapeInfo);
//
//     const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo, zShapeInfo);
//
//     if (xEws == 1 && yEws == 1 && zEws == 1 && xOrder == yOrder && xOrder == zOrder && (xOrder == 'c' || shapesSame)) {
//         xOffsets = yOffsets = zOffsets = nullptr;
//     }
//     else if(xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, yShapeInfo))) {
//         xOffsets = yOffsets = nullptr;
//         zOffsets = new Nd4jLong[len];
//         shape::calcOffsets(zShapeInfo, zOffsets, xOrder);
//     }
//     else if(xEws == 1 && zEws == 1 && xOrder == zOrder && (xOrder == 'c' || shape::shapeEquals(xShapeInfo, zShapeInfo))) {
//         xOffsets = zOffsets = nullptr;
//         yOffsets = new Nd4jLong[len];
//         shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
//     }
//     else if(yEws == 1 && zEws == 1 && yOrder == zOrder && (yOrder == 'c' || shape::shapeEquals(yShapeInfo, zShapeInfo))) {
//         yOffsets = zOffsets = nullptr;
//         xOffsets = new Nd4jLong[len];
//         shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
//     }
//     else if(xEws == 1) {
//         xOffsets = nullptr;
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { yOffsets = new Nd4jLong[len]; shape::calcOffsets(yShapeInfo, yOffsets, xOrder); }
//             PRAGMA_OMP_SECTION
//             { zOffsets = new Nd4jLong[len]; shape::calcOffsets(zShapeInfo, zOffsets, xOrder); }
//         }
//     }
//     else if(yEws == 1) {
//         yOffsets = nullptr;
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets, yOrder); }
//             PRAGMA_OMP_SECTION
//             { zOffsets = new Nd4jLong[len]; shape::calcOffsets(zShapeInfo, zOffsets, yOrder); }
//         }
//     }
//     else if(zEws == 1) {
//         zOffsets = nullptr;
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets, zOrder); }
//             PRAGMA_OMP_SECTION
//             { yOffsets = new Nd4jLong[len]; shape::calcOffsets(yShapeInfo, yOffsets, zOrder); }
//         }
//     }
//     else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo)) {
//         xOffsets = new Nd4jLong[len];
//         shape::calcOffsets(xShapeInfo, xOffsets);
//         yOffsets = zOffsets = xOffsets;
//     }
//     else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets); }
//             PRAGMA_OMP_SECTION
//             { zOffsets = new Nd4jLong[len]; shape::calcOffsets(zShapeInfo, zOffsets); }
//         }
//         yOffsets = xOffsets;
//     }
//     else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets); }
//             PRAGMA_OMP_SECTION
//             { yOffsets = new Nd4jLong[len]; shape::calcOffsets(yShapeInfo, yOffsets); }
//         }
//         zOffsets = xOffsets;
//     }
//     else {
//         PRAGMA_OMP_PARALLEL_SECTIONS
//         {
//             PRAGMA_OMP_SECTION
//             { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets); }
//             PRAGMA_OMP_SECTION
//             { yOffsets = new Nd4jLong[len]; shape::calcOffsets(yShapeInfo, yOffsets); }
//             PRAGMA_OMP_SECTION
//             { zOffsets = new Nd4jLong[len]; shape::calcOffsets(zShapeInfo, zOffsets); }
//         }
//     }
// }

 //////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD int excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, Nd4jLong*& shapeNoUnities, Nd4jLong*& stridesNoUnities) {
+
+    const int rank = shape::rank(inShapeInfo);
+    const int numOfNonUnities = shape::numOfNonUnitDims(rank, shape::shapeOf(inShapeInfo));
+
+    if(numOfNonUnities == rank) {       // no unities in shape, no copy procedure
+        shapeNoUnities   = const_cast<Nd4jLong*>(inShapeInfo) + 1;
+        stridesNoUnities = const_cast<Nd4jLong*>(inShapeInfo) + 1 + rank;
+        return numOfNonUnities;
+    }
+
+    for(uint j = 0, i = 0; i < rank; ++i) {
+        if(shape::shapeOf(inShapeInfo)[i] != 1) {
+            shapeNoUnities[j] = shape::shapeOf(inShapeInfo)[i];
+            shapeNoUnities[numOfNonUnities + j++] = shape::stride(inShapeInfo)[i];
+        }
+    }
+
+    stridesNoUnities = shapeNoUnities + numOfNonUnities;
+
+    return numOfNonUnities;
+}

 //////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD void calcOffsets(const Nd4jLong *xShapeInfo, Nd4jLong*& xOffsets, const Nd4jLong *yShapeInfo, Nd4jLong*& yOffsets, const char order) {
-
-    // we assume all array have same length
-    const Nd4jLong len = shape::length(xShapeInfo);
-
-    const Nd4jLong xEws = shape::elementWiseStride(xShapeInfo);
-    const Nd4jLong yEws = shape::elementWiseStride(yShapeInfo);
-
-    const char xOrder = shape::order(xShapeInfo);
-    const char yOrder = shape::order(yShapeInfo);
-
-    const bool shapesSame = shape::shapeEquals(xShapeInfo, yShapeInfo);
-
-    if (xEws == 1 && yEws == 1 && xOrder == yOrder && (xOrder == 'c' || shapesSame)) {
-        xOffsets = yOffsets = nullptr;
-    }
-    else if(xEws == 1) {
-        xOffsets = nullptr;
-        yOffsets = new Nd4jLong[len];
-        shape::calcOffsets(yShapeInfo, yOffsets, xOrder);
-    }
-    else if(yEws == 1) {
-        yOffsets = nullptr;
-        xOffsets = new Nd4jLong[len];
-        shape::calcOffsets(xShapeInfo, xOffsets, yOrder);
-    }
-    else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
-        xOffsets = new Nd4jLong[len];
-        shape::calcOffsets(xShapeInfo, xOffsets);
-        yOffsets = xOffsets;
-    }
-    else {
-        PRAGMA_OMP_PARALLEL_SECTIONS
-        {
-            PRAGMA_OMP_SECTION
-            { xOffsets = new Nd4jLong[len]; shape::calcOffsets(xShapeInfo, xOffsets); }
-            PRAGMA_OMP_SECTION
-            { yOffsets = new Nd4jLong[len]; shape::calcOffsets(yShapeInfo, yOffsets); }
-        }
-    }
-}
+INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo) {
+
+    outShapeInfo[0] = inShapeInfo[0] - dimsSize;
+
+    for(uint j = 0, k = 0, i = 0; i < inShapeInfo[0]; ++i) {
+        if(j < dimsSize && i == dimsToExclude[j]) {
+            ++j;
+            continue;
+        }
+        shape::shapeOf(outShapeInfo)[k]  = shape::shapeOf(inShapeInfo)[i];
+        shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
+    }
+
+    outShapeInfo[2 * outShapeInfo[0] + 1] = shape::type(inShapeInfo);              // type
+    *shape::ews(outShapeInfo)             = shape::elementWiseStride(inShapeInfo); // ews
+    outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo);             // order
+}

 }

 #endif /* SHAPE_H_ */
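excludeUnitiesFromShapeInfo only filters out extents equal to 1 together with their strides, aliasing into the existing buffer when nothing needs to be dropped. The filtering itself is trivial; a standalone equivalent over plain vectors (example data assumed, and this does not touch the packed shapeInfo layout):

#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> shape   = {2, 1, 3, 1};
    std::vector<long long> strides = {3, 3, 1, 1};

    std::vector<long long> shapeNoUnities, stridesNoUnities;
    for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] != 1) {                  // unit dimensions carry no addressing information
            shapeNoUnities.push_back(shape[i]);
            stridesNoUnities.push_back(strides[i]);
        }
    }

    for (size_t i = 0; i < shapeNoUnities.size(); ++i)
        printf("%lld/%lld ", shapeNoUnities[i], stridesNoUnities[i]);   // prints: 2/3 3/1
    printf("\n");
}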
@@ -84,7 +84,7 @@ namespace functions {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         int totalThreads = gridDim.x * blockDim.x;

-        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {

             for (int i = tid; i < length; i += totalThreads)
                 z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -89,7 +89,7 @@ namespace functions {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         int totalThreads = gridDim.x * blockDim.x;

-        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {

             for (int i = tid; i < length; i += totalThreads)
                 z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -97,7 +97,7 @@ namespace functions {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         int totalThreads = gridDim.x * blockDim.x;

-        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {

             for (Nd4jLong i = tid; i < length; i += totalThreads)
                 z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -87,7 +87,7 @@ namespace functions {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         int totalThreads = gridDim.x * blockDim.x;

-        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {

             for (int i = tid; i < length; i += totalThreads)
                 z[i * zEws] = OpType::op(x[i * xEws], params);

@@ -89,7 +89,7 @@ namespace functions {
         auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         int totalThreads = gridDim.x * blockDim.x;

-        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
+        if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {

             for (int i = tid; i < length; i += totalThreads)
                 z[i * zEws] = OpType::op(x[i * xEws], params);
@@ -21,17 +21,22 @@
 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_tensormmul)

+#include <numeric>
 #include <helpers/ShapeUtils.h>
 #include <ops/declarable/CustomOperations.h>
 #include <MmulHelper.h>

 namespace nd4j {
 namespace ops {
-    CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
+
+////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {

     auto a = INPUT_VARIABLE(0);
     auto b = INPUT_VARIABLE(1);

-    auto c = OUTPUT_VARIABLE(0); //
+    auto c = OUTPUT_VARIABLE(0);

     REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
@@ -40,20 +45,20 @@ namespace nd4j {
     int axe1_size = INT_ARG(axe0_size+1);
     std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
     for (int e = 0; e < axe0_size; e++)
-        axes_0[e] = (int) INT_ARG(e+1);
+        axes_0[e] = (int)INT_ARG(e + 1);

     for (int e = 0; e < axe1_size; e++)
-        axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
+        axes_1[e] = (int)INT_ARG(e + axe0_size + 2);

     nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());

     MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
     return Status::OK();
 }
 DECLARE_SYN(tensordot, tensormmul);

+////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(tensormmul) {

     auto aShapeInfo = inputShape->at(0);
     auto bShapeInfo = inputShape->at(1);
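The integer arguments encode both axis lists in one flat vector: INT_ARG(0) is the number of contraction axes of A, followed by those axes, then the number of contraction axes of B, followed by its axes. A self-contained sketch of that layout and of the parsing performed above (the concrete values are an assumed example meaning "contract A axis 1 with B axis 0", i.e. an ordinary matrix multiply of two rank-2 inputs):

#include <cstdio>
#include <vector>

// Sketch of the flat integer-argument layout consumed by tensormmul:
// { numAxesA, axesA..., numAxesB, axesB... }
int main() {
    std::vector<int> iargs = {1, 1, 1, 0};          // contract A axis 1 with B axis 0

    int axe0Size = iargs[0];
    std::vector<int> axesA(iargs.begin() + 1, iargs.begin() + 1 + axe0Size);
    int axe1Size = iargs[axe0Size + 1];
    std::vector<int> axesB(iargs.begin() + axe0Size + 2, iargs.begin() + axe0Size + 2 + axe1Size);

    printf("A axes: %d   B axes: %d\n", axesA[0], axesB[0]);   // A axes: 1   B axes: 0
}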
@@ -76,15 +81,114 @@ namespace nd4j {
     auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);

     return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
 }

-DECLARE_TYPES(tensormmul) {
+////////////////////////////////////////////////////////////////////////
+DECLARE_TYPES(tensormmul) {
     getOpDescriptor()
          ->setAllowedInputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF})
          ->setAllowedInputTypes(1, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF})
          ->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType::DOUBLE, DataType::HALF});
+}
+
+////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
+
+    auto A    = INPUT_VARIABLE(0);
+    auto B    = INPUT_VARIABLE(1);
+    auto dLdC = INPUT_VARIABLE(2);
+
+    auto dLdA = OUTPUT_VARIABLE(0);
+    auto dLdB = OUTPUT_VARIABLE(1);
+
+    REQUIRE_TRUE((A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
+
+    int axe0Size = INT_ARG(0);
+    int axe1Size = INT_ARG(axe0Size + 1);
+
+    auto Arank    = A->rankOf();
+    auto Brank    = B->rankOf();
+    auto dLdCrank = dLdC->rankOf();
+
+    REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
+    REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
+
+    // building axes
+    std::vector<int> axes0(axe0Size), axes1(axe1Size);
+    for (uint e = 0; e < axe0Size; e++)
+        axes0[e] = (int)INT_ARG(e + 1);
+    for (uint e = 0; e < axe1Size; e++)
+        axes1[e] = (int)INT_ARG(e + axe0Size + 2);
+
+    std::vector<int> permutAt, permutBt;
+    std::vector<Nd4jLong> shapeAt, shapeBt;
+
+    ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
+
+    // special case for scalar value
+    if (dLdC->isScalar()) {
+
+        dLdA->assign((*dLdC) * *B);
+        dLdB->assign((*dLdC) * *A);
+
+        return Status::OK();
+    }
+
+    std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
+    std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
+
+    // rank always have to be divided by 2
+    std::vector<int> axesAdLdC, axesBdLdC;
+    if (dLdCrank > 1) {
+        axesAdLdC.resize(dLdCrank / 2);
+        std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
+        axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
+    }
+    else {
+        axesAdLdC.push_back(0);
+        axesBdLdC.push_back(0);
+    }
+
+    // calculate dLdA
+    MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
+
+    // calculate dLdB
+    MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
+
+    return Status::OK();
+}
+
+////////////////////////////////////////////////////////////////////////
+DECLARE_SHAPE_FN(tensormmul_bp) {
+
+    auto aShapeInfo  = inputShape->at(0);
+    auto bShapeInfo  = inputShape->at(1);
+    auto dLShapeInfo = inputShape->at(2);
+
+    REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
+                 (ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
+
+    Nd4jLong* dLdAShapeInfo = nullptr;
+    Nd4jLong* dLdBShapeInfo = nullptr;
+
+    COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
+    COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
+
+    return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
+}
+
+////////////////////////////////////////////////////////////////////////
+DECLARE_TYPES(tensormmul_bp) {
+    getOpDescriptor()
+        ->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
+        ->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
+        ->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
+        ->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
+        ->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
+}
+
 }
 }

 #endif
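A hedged usage sketch for the new backprop op. The shapes, axes and variable names are assumptions chosen for illustration, and the execute() overload mirrors the conv2d_bp call pattern that appears later in this change set; it is not taken from the op's own tests. The trailing comments restate the gradient logic implemented above.

#include <NDArrayFactory.h>
#include <ops/declarable/CustomOperations.h>

void tensormmulBpExample() {
    auto A    = nd4j::NDArrayFactory::create<float>('c', {2, 3});
    auto B    = nd4j::NDArrayFactory::create<float>('c', {3, 4});
    auto dLdC = nd4j::NDArrayFactory::create<float>('c', {2, 4});   // gradient w.r.t. C = A x B

    auto dLdA = nd4j::NDArrayFactory::create<float>('c', {2, 3});   // same shape as A
    auto dLdB = nd4j::NDArrayFactory::create<float>('c', {3, 4});   // same shape as B

    nd4j::ops::tensormmul_bp op;
    // iArgs = {1, 1, 1, 0}: one contraction axis per input, A axis 1 against B axis 0
    auto status = op.execute({&A, &B, &dLdC}, {&dLdA, &dLdB}, {}, {1, 1, 1, 0}, {});
    (void) status;

    // dLdA = tensorDot(dLdC, B) over B's free axes, permuted back to A's layout;
    // dLdB = tensorDot(A, dLdC) over A's free axes, permuted back to B's layout.
}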
@@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
     }

     auto inputReshaped   = input  ->reshape(input->ordering(),  reshapeForInput);
-    auto outputReshaped  = output ->reshape(output->ordering(), reshapeForOutput);
+    auto outputReshaped  = output ->reshape(output->ordering(), reshapeForOutput, false);
     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});        // [kW, iC, oC] -> [1, kW, iC, oC]

     nd4j::ops::conv2d conv2d;

@@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     }

     auto inputReshaped   = input  ->reshape(input->ordering(), reshapeForInput);
-    auto gradIReshaped   = gradI  ->reshape(gradI->ordering(), reshapeForInput);
+    auto gradIReshaped   = gradI  ->reshape(gradI->ordering(), reshapeForInput, false);
     auto gradOReshaped   = gradO  ->reshape(gradO->ordering(), reshapeForGradO);
     auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});        // [kW, iC, oC] -> [1, kW, iC, oC]
-    auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),   {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});        // [kW, iC, oC] -> [1, kW, iC, oC]
+    auto gradWReshaped   = gradW  ->reshape(gradW->ordering(),   {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false); // [kW, iC, oC] -> [1, kW, iC, oC]

     nd4j::ops::conv2d_bp conv2dBP;
     auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
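The extra boolean passed to reshape() at these call sites appears to control whether the old buffer is copied into the newly allocated one; for arrays that are only written afterwards (outputs, gradient buffers) that copy is pure overhead, hence the `false`. A toy illustration of the trade-off in plain C++, deliberately using std::vector instead of NDArray so it stands on its own (names are illustrative only):

#include <algorithm>
#include <vector>

// Allocate a destination of the new size and skip the element copy when the
// caller is about to overwrite every element anyway.
std::vector<float> reshapeBuffer(const std::vector<float>& old, size_t newLen, bool copyOldBuffer) {
    std::vector<float> fresh(newLen);
    if (copyOldBuffer) {
        size_t n = std::min(old.size(), newLen);
        for (size_t i = 0; i < n; ++i)       // preserve old contents only when they are still needed
            fresh[i] = old[i];
    }
    return fresh;                            // callers that fully overwrite pass copyOldBuffer = false
}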
@@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
     //----- calculation of gradO -----//
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
         gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot);                           // sum over bS oD oH oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;

@@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
     // ----- calculation of gradB ----- //
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
         gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3});                                 // sum over bS, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;

@@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
     // ----- calculation of gradB ----- //
     if(gradB) {
         if(gradB->rankOf() == 2)
-            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
+            gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
         gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4});                              // sum over bS, oD, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
@@ -61,7 +61,7 @@ namespace nd4j {
     }

     auto source = inRank == 4 ? image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}) : image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);

     return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
 }

@@ -62,7 +62,7 @@ namespace nd4j {
     REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");

     auto source = inRank == 4 ? image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}) : image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);

     return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
 }

@@ -43,7 +43,7 @@ namespace nd4j {
     REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());

     auto source = inRank == 4 ? image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}) : image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+    auto target = inRank == 4 ? output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);

     if (block.width() > 1) {
         auto newImageSize = INPUT_VARIABLE(1);

@@ -63,7 +63,7 @@ namespace nd4j {
     REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);

     auto source = inRank == 4 ? *image : image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
-    auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
+    auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);

     return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
 }
@@ -71,7 +71,7 @@ namespace nd4j {
     }

     if (block.isInplace()) {
-        output->reshapei(input->ordering(), shape);
+        output->reshapei(input->ordering(), shape, false);
     } else {
         auto tmp = input->reshape(input->ordering(), shape);
         output->assign(tmp);
@@ -58,6 +58,7 @@ namespace nd4j {
         */
         #if NOT_EXCLUDED(OP_tensormmul)
         DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
+        DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1);
         #endif

         /**
@@ -432,7 +432,7 @@ namespace nd4j {
     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

     NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
-    NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
+    NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);

     helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
     MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);                            // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]

@@ -505,7 +505,7 @@ namespace nd4j {
     if(gradB) {
         NDArray* gradBR = gradB;
         if(gradB->rankOf() == 2)
-            gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
+            gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
         gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1});                      // sum over bS, oH, oW

         if(gradBR != gradB)
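MmulHelper::tensorDot is called here with "modifier" lists rather than axis vectors. Judging from the shapes in the trailing comment, each modifier appears to describe a permutation followed by a reshape applied to the corresponding array before a batched matrix multiply; the inline weights modifier illustrates this. A hedged reading of the call (this interpretation is inferred from the shapes, not confirmed against MmulHelper's own documentation):

// weights have shape [kH, kW, iC, mC]
//   permute {2,0,1,3}        -> [iC, kH, kW, mC]
//   reshape {iC, kH*kW, mC}  -> [iC, kH*kW, mC]
// columns (via modifColumns, defined earlier in the file) end up as [iC, bS*oH*oW, kH*kW],
// so the per-iC product is [bS*oH*oW, kH*kW] x [kH*kW, mC] = [bS*oH*oW, mC],
// which modifOutput maps back into outputReshaped.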
@@ -30,7 +30,7 @@ namespace helpers {
 void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
     auto _a = a->reshape(a->ordering(), {-1, 3});
     auto _b = b->reshape(b->ordering(), {-1, 3});
-    auto _o = o->reshape(o->ordering(), {-1, 3});
+    auto _o = o->reshape(o->ordering(), {-1, 3}, false);

     auto tadsA = _a.allTensorsAlongDimension({1});
     auto tadsB = _b.allTensorsAlongDimension({1});
@@ -244,14 +244,14 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o

     // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]

-    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)});
+    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false);
     outputRearranged0.permutei({2, 3,0, 4,1, 5});

     if(input.lengthOf() == output.lengthOf()) {
         outputRearranged0.assign(input);
     }
     else {
-        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)});
+        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false);
         BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);

         if(output.getBuffer() != outputRearranged1.getBuffer())

@@ -352,7 +352,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
     for(int j = 1; j < rank; ++i, ++j)
         temp[i] = output.sizeAt(j);

-    NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
+    NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);

     //*** construct permuting std::vector for permutation of output array ***//

@@ -382,7 +382,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
     for(i = 1; i < rank; ++i)
         temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);

-    NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
+    NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);

     BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES);
@@ -59,7 +59,7 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND
 void FORCEINLINE _crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
     auto a_ = a->reshape(a->ordering(), {-1, 3});
     auto b_ = b->reshape(b->ordering(), {-1, 3});
-    auto o_ = o->reshape(o->ordering(), {-1, 3});
+    auto o_ = o->reshape(o->ordering(), {-1, 3}, false);

     auto tadsA = a_.allTensorsAlongDimension({1});
     auto tadsB = b_.allTensorsAlongDimension({1});
@@ -322,7 +322,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
     ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

     NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
-    NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
+    NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);

     helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext()));  // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
     MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput);                            // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]

@@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
     NDArray* gradBR = gradB;
     if(gradB->rankOf() == 2)
         gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-    gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot);                              // sum over bS, oH, oW
+    gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false);                       // sum over bS, oH, oW
     if(gradBR != gradB)
         delete gradBR;
 }

@@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
     NDArray* gradBR = gradB;
     if(gradB->rankOf() == 2)
         gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-    gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1});                          // sum over bS, oH, oW
+    gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false);                   // sum over bS, oH, oW
    if(gradBR != gradB)
        delete gradBR;
 }
@@ -313,7 +313,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o

     // [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]

-    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
+    NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false);
     outputRearranged0.permutei({2, 3,0, 4,1, 5});

     if(input.lengthOf() == output.lengthOf()) {

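For readers tracing the reshape/permute pair above, a worked shape example may help. The sizes below are illustrative assumptions (no padding, blockSize = 2), not values taken from the op itself:

    // Assumed sizes: bS = 2, iH = iW = 4, iC = 1, blockSize = 2, zero padding.
    // output comes in as [bS*blockSize*blockSize, oH, oW, iC] = [8, 2, 2, 1].
    NDArray out('c', {8, 2, 2, 1}, nd4j::DataType::FLOAT32);
    NDArray view = out.reshape(out.ordering(), {2, 2, 2, 2, 2, 1}, false); // [blockSize, blockSize, bS, oH, oW, iC]
    view.permutei({2, 3,0, 4,1, 5});                                       // [bS, oH, blockSize, oW, blockSize, iC]
    // i.e. the (padded) input [bS, iH, iW, iC] with each spatial axis split into (outer, block) pairs.
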
@@ -322,7 +322,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
     }
     else {

-        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
+        NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false);

         const int threadsPerBlock = MAX_NUM_THREADS / 2;
         const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

@@ -439,7 +439,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
     for(int j = 1; j < rank; ++i, ++j)
         temp[i] = output.sizeAt(j);

-    NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
+    NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);

     //*** construct permuting std::vector for permutation of output array ***//

@@ -469,7 +469,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
     for(i = 1; i < rank; ++i)
         temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);

-    NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
+    NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);

     const int threadsPerBlock = MAX_NUM_THREADS / 4;
     const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

@@ -471,9 +471,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
     if(cI)
         cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
     if(hL)
-        hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
+        hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
     if(cL)
-        cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));
+        cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));

     lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);

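A common thread in the hunks above is the extra trailing boolean on reshape. A minimal sketch of the assumed intent follows: when the reshaped array is used purely as a write target, copying the original buffer into a freshly allocated one would be wasted work. The shapes and the assign call are illustrative only, not taken from the changed code:

    NDArray out('c', {2, 3, 4}, nd4j::DataType::FLOAT32);
    // Assumption: the trailing 'false' skips copying out's old contents into the new buffer
    // (if one has to be allocated), which is safe here because the result is fully overwritten.
    NDArray outReshaped = out.reshape(out.ordering(), {6, 4}, false);
    outReshaped.assign(1.f);
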
@@ -321,6 +321,280 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
     delete results;
 }

+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot5) {
+
+    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot6) {
+
+    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot7) {
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot8) {
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot9) {
+
+    // NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
+    // z.linspace(1);
+    // z.printShapeInfo();
+    // z.printIndexedBuffer();
+    // z.reshapei('c', {4,3});
+    // z.printShapeInfo();
+    // z.printIndexedBuffer();
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot10) {
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot11) {
+
+    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot12) {
+
+    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot13) {
+
+    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot14) {
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot15) {
+
+    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
+    auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
+    auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(result));
+    ASSERT_TRUE(expected.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot16) {
+
+    NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
+    NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
+    NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul op;
+    auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *result = results->at(0);
+
+    ASSERT_TRUE(exp.isSameShape(result));
+    ASSERT_TRUE(exp.equalsTo(result));
+
+    delete results;
+}
+
+////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests1, TestTensorDot17) {
+
+    NDArray x('f', {16,16}, nd4j::DataType::FLOAT32);
+    NDArray y('f', {1000,16}, nd4j::DataType::FLOAT32);
+    NDArray z('c', {16,1000}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul op;
+    auto status = op.execute({&x, &y}, {&z}, {}, {1,1, 1,1}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+}
+
 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests1, DivergentCheck1) {
     auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation("switch");

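The integer arguments passed to tensormmul in the new tests appear to follow one layout: the number of contraction axes for the first input, those axes, then the same for the second input. Reading TestTensorDot13 that way is consistent with its expected output shape; the snippet below only restates that reading and is not taken from the op's documentation:

    // {2,0,2, 2,1,0}: contract x axes {0,2} (sizes 2 and 4) with y axes {1,0} (sizes 2 and 4);
    // the remaining axes (size 3 each) form the {3,3} result checked by the test.
    auto x = NDArrayFactory::create<double>('c', {2,3,4});
    auto y = NDArrayFactory::create<double>('c', {4,2,3});

    nd4j::ops::tensormmul op;
    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});   // result shape: {3,3}
    delete results;
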
@@ -708,30 +708,6 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
     ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList));
 }

-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests12, tensormmul_6) {
-
-    NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
-    NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
-    NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-    // exp.printShapeInfo();
-    // result->printShapeInfo();
-    // result->printIndexedBuffer();
-
-    ASSERT_TRUE(exp.isSameShape(result));
-    ASSERT_TRUE(exp.equalsTo(result));
-
-    delete results;
-
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {

@@ -1560,3 +1560,447 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
     delete resultsB;
 }
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
+
+    NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, nd4j::DataType::FLOAT32);
+    NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
+
+    NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 1 }, { 1 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(B.isSameShape(*dLdAbp));
+    ASSERT_TRUE(B.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(A.isSameShape(*dLdBbp));
+    ASSERT_TRUE(A.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) {
+
+    NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
+    NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
+
+    NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
+    NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
+
+    NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
+    NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
+
+    NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, nd4j::DataType::FLOAT32);
+
+    auto dLdC = NDArrayFactory::create<float>(1.f);
+
+    nd4j::ops::tensormmul_bp op_bp;
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(B.isSameShape(*dLdAbp));
+    ASSERT_TRUE(B.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(A.isSameShape(*dLdBbp));
+    ASSERT_TRUE(A.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) {
+
+    NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
+    NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
+
+    NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, nd4j::DataType::FLOAT32);
+    NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) {
+
+    NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
+    NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) {
+
+    NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
+    NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) {
+
+    NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
+    NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) {
+
+    NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
+    NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::FLOAT32);
+    NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                        1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
+                                        2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, nd4j::DataType::FLOAT32);
+    NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) {
+
+    NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::DOUBLE);
+    NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::DOUBLE);
+    NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                        1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
+                                        2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
+
+    NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, nd4j::DataType::DOUBLE);
+    NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, nd4j::DataType::DOUBLE);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) {
+
+    NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
+
+    NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::DOUBLE);
+
+    NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                              1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                              1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                              1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
+                                              1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::DOUBLE);
+
+    NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, nd4j::DataType::DOUBLE);
+    NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, nd4j::DataType::DOUBLE);
+
+    nd4j::ops::tensormmul_bp op_bp;
+
+    auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
+
+    auto* dLdAbp = resultsBP->at(0);
+    auto* dLdBbp = resultsBP->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdAbp));
+    ASSERT_TRUE(dA.equalsTo(*dLdAbp));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdBbp));
+    ASSERT_TRUE(dB.equalsTo(*dLdBbp));
+
+    delete resultsBP;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) {
+
+    NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
+    NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
+
+    NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, nd4j::DataType::FLOAT32);
+
+    NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, nd4j::DataType::FLOAT32);
+    NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::tensormmul_bp op;
+    auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 });
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto* dLdA = results->at(0);
+    auto* dLdB = results->at(1);
+
+    ASSERT_TRUE(dA.isSameShape(*dLdA));
+    ASSERT_TRUE(dA.equalsTo(*dLdA));
+
+    ASSERT_TRUE(dB.isSameShape(*dLdB));
+    ASSERT_TRUE(dB.equalsTo(*dLdB));
+
+    delete results;
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
+
+    NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
+    NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
+
+    NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
+
+    const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
+    const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
+
+    nd4j::ops::tensormmul op;
+    nd4j::ops::tensormmul_bp op_bp;
+
+    const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
+    ASSERT_TRUE(isGradCorrect);
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
+
+    NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
+    NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
+
+    NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
+
+    const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
+    const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
+
+    nd4j::ops::tensormmul op;
+    nd4j::ops::tensormmul_bp op_bp;
+
+    const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
+    ASSERT_TRUE(isGradCorrect);
+}

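Several of the new backprop tests (BP2, BP6) rely on one easy-to-check identity: when every axis of A and B is contracted, C is a scalar, and with dL/dC = 1 the gradients reduce to dL/dA = B and dL/dB = A. A small standalone check in the same spirit, with shapes and values chosen here purely for illustration:

    NDArray A('c', {2, 2}, {1, 2, 3, 4}, nd4j::DataType::FLOAT32);
    NDArray B('c', {2, 2}, {5, 6, 7, 8}, nd4j::DataType::FLOAT32);
    auto dLdC = NDArrayFactory::create<float>(1.f);               // scalar upstream gradient

    nd4j::ops::tensormmul_bp op_bp;
    auto grads = op_bp.evaluate({&A, &B, &dLdC}, {}, {2,0,1, 2,0,1}, {});  // contract both axes of A and of B

    // grads->at(0) should equal B, grads->at(1) should equal A
    delete grads;
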
@@ -578,246 +578,6 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
 }

-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot5) {
-
-    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot6) {
-
-    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot7) {
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot8) {
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot9) {
-
-    // NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
-    // z.linspace(1);
-    // z.printShapeInfo();
-    // z.printIndexedBuffer();
-    // z.reshapei('c', {4,3});
-    // z.printShapeInfo();
-    // z.printIndexedBuffer();
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot10) {
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot11) {
-
-    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot12) {
-
-    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot13) {
-
-    auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot14) {
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
-////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests2, TestTensorDot15) {
-
-    auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
-    auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
-    auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
-
-    nd4j::ops::tensormmul op;
-    auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
-
-    ASSERT_EQ(ND4J_STATUS_OK, results->status());
-
-    auto *result = results->at(0);
-
-    ASSERT_TRUE(expected.isSameShape(result));
-    ASSERT_TRUE(expected.equalsTo(result));
-
-    delete results;
-}
-
 ////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {

@@ -2043,34 +2043,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN);

     ASSERT_TRUE(isGradCorrect);

-    //************************************//
-/*  exclusive = 1; reverse = 0;
-
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
-    ASSERT_EQ(Status::OK(), result->status());
-    z = result->at(0);
-    ASSERT_TRUE(expTF.equalsTo(z));
-    delete result;
-*/
-    //************************************//
-/*  exclusive = 0; reverse = 1;
-
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
-    ASSERT_EQ(Status::OK(), result->status());
-    z = result->at(0);
-    ASSERT_TRUE(expFT.equalsTo(z));
-    delete result;
-*/
-    //************************************//
-/*  exclusive = 1; reverse = 1;
-
-    result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
-    ASSERT_EQ(Status::OK(), result->status());
-    z = result->at(0);
-    ASSERT_TRUE(expTT.equalsTo(z));
-    delete result;
-*/
 }

 ////////////////////////////////////////////////////////////////////////////////

@@ -2079,11 +2051,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
     auto inputC = NDArrayFactory::create<double>('c', {2, 2});
     auto axis = NDArrayFactory::create<double>(1.);

-    // auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.});
-    // auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024});
-
-    // auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++
-    // auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
     auto gradO = NDArrayFactory::create<double>('c', {2, 2});

     int exclusive, reverse;