profiling of stack and unstack ops (#261)

* - profiling of stack and unstack ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix bug in cpu concat op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correction of cuda stack and unstack

Signed-off-by: Yurii <iuriish@yahoo.com>

* - change shape.h method which operates on strides of unity dimensions

Signed-off-by: Yurii <iuriish@yahoo.com>

* - rearrange stack tests

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct evaluation of smallest stride for moving through contiguous axis

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to update signature of function strideOverContigAxis in cuda concat and split ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove ShapeUtils::shapeAsString calls applied before input array validation

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further removal of ShapeUtils::shapeAsString

Signed-off-by: Yurii <iuriish@yahoo.com>

* - take sub-array shapeInfo/offset calculation out of NDArray class
- add possibility of contiguous memory copy in execTransformAny op if opNum == assign

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct test_empty_scatter_2 in EmptyTests.cpp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - profiling of slice op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - get rid of contiguous memcpy for some cases in concat and split ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to declare void nd4j::SpecialMethods<T>::splitCpuGeneric

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct typo in calculation of threads in cuda split op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - forgot to correct another set of threads variables in split cuda ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further conflict resolution

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
Yurii Shyrma 2020-03-03 06:32:37 +02:00 committed by GitHub
parent 0f581e74e3
commit 78934c17ad
61 changed files with 1947 additions and 1523 deletions

@ -903,14 +903,6 @@ namespace sd {
*/
void transposei();
/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
NDArray tensorAlongDimension(Nd4jLong index, const std::initializer_list<int>& dimensions) const;
NDArray tensorAlongDimension(Nd4jLong index, const std::vector<int>& dimensions) const;
/**
* returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on

@ -1197,14 +1197,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
}
// memcpy is allowed only for same order c && same ews (being equal to 1)
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism);
NDArray::registerSpecialUse({this}, {&other});
}
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism);
NDArray::registerSpecialUse({this}, {&other});
}
}
@ -4810,30 +4805,6 @@ ResultSet NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
return result;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& dimensions) const {
std::vector<int> copy(dimensions);
shape::checkDimensions(rankOf(), copy);
Nd4jLong tadLength = shape::tadLength(this->getShapeInfo(), copy.data(), copy.size());
Nd4jLong numTads = this->lengthOf() / tadLength;
if (index >= numTads)
throw std::runtime_error("Can't get index higher than total number of TADs");
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
NDArray array(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset());
array._isView = true;
return array;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list<int>& dimensions) const {
return tensorAlongDimension(index, std::vector<int>(dimensions));
}
////////////////////////////////////////////////////////////////////////
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
@ -4841,63 +4812,73 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(isEmpty())
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");
const int rank = rankOf();
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
// Nd4jLong *outShapeInfo = nullptr;
// ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong);
auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo);
int numOfUntiesInSubArrShape = 0;
Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride;
for (int d = rank - 1; d >= 0; --d) {
if (idx[n * d] != idx[n * d + 1]) {
first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1;
last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1;
stride = isStrided ? idx[n * d + 2] : 1;
shapeOf[d] = (last - first + stride - 1) / stride; // ceil (last - first) / stride;
offset += first * stridesOf[d];
if(shapeOf[d] != 1)
stridesOf[d] *= stride;
}
}
Nd4jLong *newShapeInfo2 = newShapeInfo;
Nd4jLong* subArrShapeInfo = nullptr;
if(!keepUnitiesInShape) {
std::vector<int> dimsWithUnities;
int n(isStrided ? 3 : 2), first, last;
for (int d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
// calculate the number of unities in shape
for (uint d = 0; d < rankOf(); ++d) {
if(!dimsWithUnities.empty())
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
if (idx[n * d] != idx[n * d + 1]) {
first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1;
last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1;
if(last - first == 1)
++numOfUntiesInSubArrShape;
}
}
}
// check if there is possibility to set ews = 1
shape::checkStridesEwsAndOrder(newShapeInfo2);
ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong);
NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
Nd4jLong offset;
shape::calcSubArrShapeInfoAndOffset(idx.data(), getShapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape);
NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + getBufferOffset());
result._isView = true;
RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != newShapeInfo2)
RELEASE(newShapeInfo2, getContext()->getWorkspace());
RELEASE(subArrShapeInfo, getContext()->getWorkspace());
return result;
}
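For illustration, a minimal sketch of how this operator is typically invoked (not part of the diff; x is assumed to be an existing 3-D NDArray with shape {3,4,5} and 'c' order):

// illustration only: idx holds {dimStart,dimEnd} pairs, (start == end) selects the whole dimension
NDArray view = x({0,2,  0,0,  1,4}, false, false);          // -> shape {2,4,3}: indices 0..1 of dim0, all of dim1, 1..3 of dim2
// with isStrided = true, idx holds {dimStart,dimEnd,dimStride} triplets
NDArray strided = x({0,0,1,  0,4,2,  0,0,1}, false, true);  // -> shape {3,2,5}: every 2nd index of dim1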
////////////////////////////////////////////////////////////////////////
NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& dimsToExclude, bool keepUnitiesInShape) const {
std::vector<Nd4jLong> idxRanges(2 * rankOf());
ShapeUtils::evalIdxRangesForSubArr(subArrIdx, _shapeInfo, dimsToExclude, idxRanges.data());
const auto rank = rankOf();
const auto subArrRank = static_cast<int>(dimsToExclude.size());
if(subArrRank > rank)
throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector<int>& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !");
memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong));
// subArrRank == 0 means whole array, idxRanges should contain zeros only
if(subArrRank != 0) {
std::vector<Nd4jLong> shapeOfSubArr(subArrRank), indexes(subArrRank);
for(int i = 0; i < subArrRank; ++i)
shapeOfSubArr[i] = sizeAt(dimsToExclude[i]);
shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data());
for(int i = 0; i < subArrRank; ++i) {
int currIdx = 2 * dimsToExclude[i];
idxRanges[currIdx] = indexes[i];
idxRanges[currIdx + 1] = indexes[i] + 1;
}
}
return (*this)(idxRanges, keepUnitiesInShape);
}
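A worked example of the index arithmetic above (illustration only, not part of the diff):

// array shape = {2,3,4,5}, dimsToExclude = {0,2}  ->  2*4 = 8 sub-arrays of shape {3,5}
// subArrIdx = 5, shapeOfSubArr = {2,4}; shape::index2coords(5, 2, shapeOfSubArr.data(), indexes.data()) gives indexes = {1,1}
// hence idxRanges = {1,2,  0,0,  1,2,  0,0}, i.e. the sub-array view arr(1, :, 1, :)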
@ -4916,7 +4897,7 @@ void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd
ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong);
ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong);
shape::calcSubArrShapeAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape);
shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape);
}
//////////////////////////////////////////////////////////////////////////

@ -138,16 +138,6 @@ namespace sd {
*/
static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude);
/**
* evaluate indexes ranges that define sub-array of array having shape=shapeInfo
* subArrIdx - index of current sub-array
* shapeInfo - shapeInfo of array for which to evaluate sub-arrays
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-arrays along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5],
* if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned.
* idxRanges - where to put result, the length of idxRanges must be equal to 2*shapeInfo[0]
*/
static void evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude, Nd4jLong* idxRanges);
/**
* return shape without unities, for example if shape is [1,2,1,3] then [2,3] will be returned
* if unities are not present in given shapeInfo then exactly identical shape will be returned, for example [2,3] -> [2,3]
@ -202,6 +192,11 @@ namespace sd {
static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims);
*/
/*
* comparing of shapes, not strides
*/
static bool areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4jLong>& shapeOnly);
};

@ -73,7 +73,7 @@ namespace sd {
auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);

@ -77,7 +77,7 @@ namespace sd {
auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
Nd4jPointer soPtr;
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));

@ -940,16 +940,16 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
std::vector<void*> aSubArrs(bS), bSubArrs(bS), cSubArrs(bS);
if(aRank > 2)
shape::calcSubArrShapeAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT();
if(bRank > 2)
shape::calcSubArrShapeAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT();
shape::calcSubArrShapeAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT();

@ -974,35 +974,6 @@ Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vecto
return numOfSubArrs;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils::evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude, Nd4jLong* idxRanges) {
const auto rank = shape::rank(shapeInfo);
const auto subArrRank = static_cast<int>(dimsToExclude.size());
if(subArrRank > rank)
throw std::invalid_argument("ShapeUtils::evalIdxRangesForSubArr static method: dimsToExclude is empty or has size > rank of array !");
if(subArrRank == 0) { // means whole array
memset(idxRanges, 0, 2 * rank * sizeof(Nd4jLong));
return;
}
std::vector<Nd4jLong> shapeOfSubArr(subArrRank), indexes(subArrRank);
for(int i = 0; i < subArrRank; ++i)
shapeOfSubArr[i] = shapeInfo[dimsToExclude[i] + 1];
shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data());
memset(idxRanges, 0, 2 * rank * sizeof(Nd4jLong));
for(int i = 0; i < subArrRank; ++i) {
int currIdx = 2 * dimsToExclude[i];
idxRanges[currIdx] = indexes[i];
idxRanges[currIdx + 1] = indexes[i] + 1;
}
}
////////////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalDimsWithoutUnities(const Nd4jLong* shapeInfo) {
@ -1080,6 +1051,19 @@ void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, co
}
}
}
bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4jLong>& shapeOnly) {
if(shape::rank(shapeInfo) != shapeOnly.size())
return false;
for(uint i = 0; i < shape::rank(shapeInfo); ++i)
if(shape::shapeOf(shapeInfo)[i] != shapeOnly[i])
return false;
return true;
}
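A minimal usage sketch of the new helper (illustration only; the shapeInfo literal below follows the {rank, shape, strides, type, ews, order} layout used in shape.h):

// illustration only: compare a shapeInfo against a plain shape vector
Nd4jLong shapeInfo[] = {3,  2,3,4,  12,4,1,  16384,1,99};   // rank 3, shape {2,3,4}, 'c' order
bool ok  = ShapeUtils::areShapesEqual(shapeInfo, {2,3,4});  // true
bool bad = ShapeUtils::areShapesEqual(shapeInfo, {2,3});    // false: ranks differ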
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

@ -117,7 +117,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2);
ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3);
ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shape, const int dim);
ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim);
template <typename T>
ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length);
@ -469,9 +469,6 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int rank(const int *shapeInfo);
ND4J_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo);
// returns pointer on elementWiseStride
ND4J_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo);
/**
* returns pointer on elementWiseStride
*/
@ -1029,7 +1026,23 @@ namespace shape {
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
ND4J_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
ND4J_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0);
/**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@ -1046,6 +1059,12 @@ namespace shape {
*/
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);
/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
@ -2961,13 +2980,13 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3);
}
INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shape, const int dim) {
if (0 == rank(shape))
INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) {
if (0 == rank(shapeInfo))
return 1;
if (dim >= 0)
return shape[1+dim];
return shapeInfo[1+dim];
else
return shape[1+(rank(shape) + dim)];
return shapeInfo[1+(rank(shapeInfo) + dim)];
}
/**
@ -4683,21 +4702,6 @@ INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char
if(contiguous) {
// for example we have shapeInfo = {3, 5,1,1, 4,4,1, ...} then we should change it to shapeInfo = {3, 5,1,1, 4,4,4, ...ews=4}
if(numOfNonUnities < rank) { // unities are present in shape
int indNonUnit = rank - 1;
while(shape::shapeOf(shapeInfo)[indNonUnit--] == 1)
for(int j = indNonUnit + 2; j < rank; ++j)
shape::stride(shapeInfo)[j] = stridesNoUnities[numOfNonUnities - 1];
for(int j = indNonUnit; j >= 0; --j)
if(shape::shapeOf(shapeInfo)[j] == 1)
shape::stride(shapeInfo)[j] = shape::shapeOf(shapeInfo)[j + 1] * shape::stride(shapeInfo)[j + 1];
}
*shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1];
shapeInfo[rank * 2 + 3] = 99;
return;
@ -4715,21 +4719,6 @@ INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char
if(contiguous) {
// for example we have shapeInfo = {3, 1,1,5, 1,4,4, ...} then we should change it to shapeInfo = {3, 1,1,5, 4,4,4, ...ews=4}
if(numOfNonUnities < rank) { // unities are present in shape
int indNonUnit = 0;
while(shape::shapeOf(shapeInfo)[indNonUnit++] == 1)
for(int j = 0; j < indNonUnit - 1; ++j)
shape::stride(shapeInfo)[j] = stridesNoUnities[0];
for(int j = indNonUnit; j < rank; ++j)
if(shape::shapeOf(shapeInfo)[j] == 1)
shape::stride(shapeInfo)[j] = shape::shapeOf(shapeInfo)[j - 1] * shape::stride(shapeInfo)[j - 1];
}
*shape::ews(shapeInfo) = stridesNoUnities[0];
shapeInfo[rank * 2 + 3] = 102;
return;
@ -4740,7 +4729,7 @@ INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) {
INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) {
const int rank = shape::rank(wholeShapeInfo);
@ -4788,6 +4777,54 @@ INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo
delete []shape;
}
//////////////////////////////////////////////////////////////////////
INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape, const bool isStrided, const int numOfUntiesInMinShape) {
const uint maxRank = shape::rank(maxShapeInfo);
minOffset = 0;
uint first, last, stride, n(isStrided ? 3 : 2);
minShapeInfo[0] = keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape;
for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) {
if (idx[step] == idx[step + 1]) { // means whole dimension
shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i];
shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i];
}
else {
first = idx[step] >= 0 ? idx[step] : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1;
last = idx[step + 1] >= 0 ? idx[step + 1] : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1;
if(last < first)
throw("shape::calcSubArrShapeInfoAndOffset: negative range in input indexes is found!");
if(isStrided) {
stride = idx[step + 2];
last /*resulting sub-array axis*/ = (last - first + stride - 1) / stride; // ceil (last - first) / stride;
}
else {
stride = 1;
last /*resulting sub-array axis*/ = last - first;
}
minOffset += first * shape::stride(maxShapeInfo)[i];
if(!keepUnitiesInShape && last == 1)
continue;
shape::shapeOf(minShapeInfo)[j] = last;
shape::stride(minShapeInfo)[j++] = last == 1 ? shape::stride(maxShapeInfo)[i] : shape::stride(maxShapeInfo)[i] * stride;
}
}
minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order
minShapeInfo[2 * shape::rank(minShapeInfo) + 1] = shape::type(maxShapeInfo); // type
shape::checkStridesEwsAndOrder(minShapeInfo);
}
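A worked example of this routine (illustration only, not part of the diff; shapeInfo literals are assumed):

// whole array: shape {3,4,5}, 'c' order, strides {20,5,1}
Nd4jLong maxShapeInfo[] = {3,  3,4,5,  20,5,1,  16384,1,99};
Nd4jLong idx[] = {1,3,  0,0,  2,5};                 // dim0 -> [1,3), dim1 -> whole, dim2 -> [2,5)
Nd4jLong minShapeInfo[10];                          // 2*rank + 4 longs
Nd4jLong minOffset;
shape::calcSubArrShapeInfoAndOffset(idx, maxShapeInfo, minShapeInfo, minOffset);
// -> sub-array shape {2,4,3}, strides {20,5,1}, minOffset = 1*20 + 2*1 = 22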
//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords) {
@ -5083,6 +5120,27 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo,
}
//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
Nd4jLong result = 9223372036854775807LL;
for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
const auto currentStride = shape::stride(inShapeInfo)[i];
if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
continue;
if(result > currentStride)
result = currentStride;
}
return result == 9223372036854775807LL ? 1 : result;
}
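A small usage sketch matching the example in the comment above (illustration only):

// shapeInfo: shape {2,5,4,3}, strides {60,1,5,20}; axis 1 is the contiguous one (stride = 1)
Nd4jLong shapeInfo[] = {4,  2,5,4,3,  60,1,5,20,  16384,0,99};
Nd4jLong step = shape::strideOverContigAxis(1, shapeInfo);  // -> 5, the smallest stride besides 1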
}
#endif /* SHAPE_H_ */

@ -1233,11 +1233,18 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) {
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
memcpy(hZ, hX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType));
}
else {
auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
}
////////////////////////////////////////////////////////////////////////

@ -926,9 +926,14 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
if (shape::isEmpty(hXShapeInfo))
return;
dim3 launchDims(512, 512, 2048);
if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) {
cudaMemcpyAsync(dZ, dX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), cudaMemcpyDeviceToDevice, *stream);
}
else {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
}
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);

@ -63,8 +63,8 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
int iC = input->sizeAt(indIOioC); // input channels
int oC = weights->sizeAt(indWoC); // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -123,8 +123,8 @@ DECLARE_SHAPE_FN(conv1d) {
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -198,10 +198,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -267,10 +267,10 @@ DECLARE_SHAPE_FN(conv1d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

@ -58,8 +58,8 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -109,8 +109,8 @@ DECLARE_SHAPE_FN(conv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -187,10 +187,10 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong>expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong>expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -242,10 +242,10 @@ DECLARE_SHAPE_FN(conv2d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -300,10 +300,10 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
@ -360,10 +360,10 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
Nd4jLong* gradIshapeInfo(nullptr);
ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);

@ -59,8 +59,8 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -139,8 +139,8 @@ DECLARE_SHAPE_FN(conv3dnew) {
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -209,10 +209,10 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@ -313,10 +313,10 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

@ -31,13 +31,13 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)
REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf());
REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf());
int kH = 1; // filter(kernel) height
int kW = 1; // filter(kernel) width
@ -46,18 +46,18 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
int pH = 0; // paddings height
int pW = 0; // paddings width
int dH = 1; // dilations height
int dW = 1; // dilations width
int dW = 1; // dilations width
int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedWeightsShape = ShapeUtils::shapeAsString({1, 1, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);
return Status::OK();
@ -71,7 +71,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
DECLARE_SHAPE_FN(pointwise_conv2d) {
Nd4jLong* inputShapeInfo = inputShape->at(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1); // [1, 1, iC, oC] always
Nd4jLong* biasShapeInfo = block.width() > 2 ? inputShape->at(2) : nullptr; // [oC]
@ -89,18 +89,18 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
indIOioC = 1;
const int bS = inputShapeInfo[1]; // batch size
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels
std::string expectedWeightsShape = ShapeUtils::shapeAsString({1, 1, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.getWorkspace());
// do not forget to put oC instead of iC in outputShapeInfo
outputShapeInfo[indIOioC + 1] = oC;
outputShapeInfo[indIOioC + 1] = oC;
shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo));

@ -73,11 +73,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDepth), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
if(weightsPoint) {
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPoint), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
}
if (bias)
REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf());
@ -151,11 +151,11 @@ DECLARE_SHAPE_FN(sconv2d) {
const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier
const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC)
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDShapeInfo), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) {
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPShapeInfo), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
}
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "SCONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@ -250,13 +250,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
mC = weightsDepth->sizeAt(indWmC); // channels multiplier
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDepth), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(gradWD), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
if(weightsPoint) {
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPoint), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(gradWP), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
}
if (bias) {
REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D_BP OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf());
@ -354,13 +354,13 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
std::string expectedGradOShapeInfo = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}));
REQUIRE_TRUE(expectedGradOShapeInfo == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShapeInfo.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDShapeInfo), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
if(weightsPShapeInfo) {
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPShapeInfo), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
}
if (biasShapeInfo)
REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && oC == shape::length(biasShapeInfo), 0, "SCONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

View File

@ -168,11 +168,10 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]

View File

@ -57,8 +57,8 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@ -174,10 +174,10 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
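A side note on the helper used in these checks: judging from the call sites (this reading is inferred from usage, not from the function's documentation), ShapeUtils::composeShapeUsingDimsAndIdx takes a single list whose first half holds dimension values and whose second half holds the positions those values take in the resulting shape, which is how one expected-shape expression covers both NCHW/NCDHW and NHWC/NDHWC layouts. A small sketch with made-up sizes:
// hypothetical values: bS=2, iC=3, oH=4, oW=5
// NCHW-style indices (indIOioC=1, indIiH=2): k-th value of the first half goes to the k-th index of the second half
std::vector<Nd4jLong> nchw = ShapeUtils::composeShapeUsingDimsAndIdx({2,3,4,5, 0,1,2,3}); // -> {2,3,4,5}
// NHWC-style indices (indIOioC=3, indIiH=1): channels land last
std::vector<Nd4jLong> nhwc = ShapeUtils::composeShapeUsingDimsAndIdx({2,3,4,5, 0,3,1,2}); // -> {2,4,5,3}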

View File

@ -170,10 +170,10 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]

View File

@ -57,8 +57,8 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
@ -176,10 +176,10 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCDHW) {
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]

View File

@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
if(!isNCHW) {
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]

View File

@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(parallel_stack, -1, 1, false, 0, 0) {
inArrs[i] = INPUT_VARIABLE(i);
const int dim = 0;
helpers::stack(block.launchContext(), inArrs, output, dim);
helpers::stack(block.launchContext(), inArrs, *output, dim);
return Status::OK();
}

View File

@ -70,7 +70,7 @@ namespace sd {
empty = true;
//Don't break to perform input validation on other dims
}
indices[2*e] = start;
indices[2*e+1] = start + size;
}
@ -80,8 +80,29 @@ namespace sd {
return Status::OK();
}
auto sub = (*input)(indices, true);
output->assign(sub);
Nd4jLong* subArrShapeInfo = nullptr;
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong);
Nd4jLong offset;
shape::calcSubArrShapeInfoAndOffset(indices.data(), input->getShapeInfo(), subArrShapeInfo, offset, true);
auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo);
NDArray::prepareSpecialUse({output}, {input});
NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
input->bufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.primary()),
input->specialBufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.special()),
output->buffer(), output->shapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
nullptr, nullptr, nullptr, true);
NDArray::registerSpecialUse({output}, {input});
RELEASE(subArrShapeInfo, block.getWorkspace());
// auto sub = (*input)(indices, true);
// output->assign(sub);
STORE_RESULT(output);
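To make the indices layout above concrete: the vector holds one {start, start + size} pair per dimension, and calcSubArrShapeInfoAndOffset turns those pairs into a sub-array shapeInfo plus the offset of its first element, which the Assign transform then copies straight into the output instead of materializing a sub-array NDArray first. A sketch with hypothetical begin/size values for a rank-2 input:
// begin = {1, 0}, size = {2, 3}  ->  rows [1,3), columns [0,3)
std::vector<Nd4jLong> begin = {1, 0}, size = {2, 3};
std::vector<Nd4jLong> indices(2 * input->rankOf());
for (int e = 0; e < input->rankOf(); e++) {
    indices[2*e]     = begin[e];            // interval start for dimension e
    indices[2*e + 1] = begin[e] + size[e];  // exclusive interval end for dimension e
}
// indices == {1,3, 0,3}; calcSubArrShapeInfoAndOffset(indices.data(), ...) then yields the
// view's shapeInfo and the element offset of input(1,0), as in the block above.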
@ -116,7 +137,7 @@ namespace sd {
REQUIRE_TRUE(begin.size() == x_rank, 0, "Begin array should have length of [%i] but got [%i] instead", x_rank, begin.size());
REQUIRE_TRUE(sz.size() == x_rank, 0, "Size array should have length of [%i] but got [%i] instead", x_rank, sz.size());
std::vector<Nd4jLong> shape;
auto empty = false;
for (int e = 0; e < x_rank; e++) {
@ -186,12 +207,12 @@ namespace sd {
size = input->sizeAt(e) - start;
}
REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less than 1", e);
indices[2*e] = start;
indices[2*e + 1] = start + size;
}
auto sub = (*output)(indices, true);
sub.assign(epsNext);
return Status::OK();
}

View File

@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
for(int i = 0; i < block.width(); ++i)
inArrs[i] = INPUT_VARIABLE(i);
helpers::stack(block.launchContext(), inArrs, output, dim);
helpers::stack(block.launchContext(), inArrs, *output, dim);
return Status::OK();
}
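For orientation, a hypothetical test-style invocation of the op above, written in the same evaluate(...) style that appears later in this changeset (exact overloads and headers are assumed, not shown in this diff):
sd::ops::stack op;
auto x = NDArrayFactory::create<float>('c', {2, 3});
auto y = NDArrayFactory::create<float>('c', {2, 3});
auto results = op.evaluate({&x, &y}, {}, {0});   // iArg 0: axis to stack along
auto z = results->at(0);                         // z has shape {2, 2, 3}
delete results;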

View File

@ -16,125 +16,114 @@
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_unstack)
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include<ops/declarable/helpers/stack.h>
namespace sd {
namespace ops {
CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) {
auto input = INPUT_VARIABLE(0);
namespace ops {
auto dim = INT_ARG(0);
if (dim < 0)
dim += input->rankOf();
CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) {
auto input = INPUT_VARIABLE(0);
auto dim = INT_ARG(0);
if (dim < 0)
dim += input->rankOf();
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower than rank of input %i, but got dimension=%i !", input->rankOf(), dim);
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower than rank of input %i, but got dimension=%i !", input->rankOf(), dim);
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
if(input->isEmpty())
return Status::OK();
if(input->isEmpty())
return Status::OK();
std::vector<int> dims;
for (int e = 0; e < input->rankOf(); e++)
if (e != dim)
dims.emplace_back(e);
if (dims.size() == 0 && input->rankOf() == 1) { // split vector into lengthOf scalars
for (Nd4jLong e = 0; e < input->lengthOf(); e++) {
auto outE = OUTPUT_VARIABLE(e);
outE->assign(input->e(e));
}
}
std::vector<NDArray*> outArrs(input->sizeAt(dim));
for(uint i = 0; i < outArrs.size(); ++i)
outArrs[i] = OUTPUT_VARIABLE(i);
auto tads = input->allTensorsAlongDimension(dims);
//nd4j_printf("Tad size: %d\n",tads.size());
for (int e = 0; e < tads.size(); e++) {
//nd4j_printf("Calling assign at index %d\n",e);
auto outE = OUTPUT_VARIABLE(e);
auto tadAtE = tads.at(e);
helpers::unstack(block.launchContext(), *input, outArrs, dim);
outE->assign(tadAtE);
return Status::OK();
}
this->storeResult(block, e, *outE);
}
DECLARE_SYN(unpack, unstack);
return Status::OK();
}
DECLARE_SHAPE_FN(unstack) {
auto inShapeInfo = inputShape->at(0);
DECLARE_SYN(unpack, unstack);
auto dim = INT_ARG(0);
if (dim < 0)
dim += shape::rank(inShapeInfo);
DECLARE_SHAPE_FN(unstack) {
auto inShape = inputShape->at(0);
REQUIRE_TRUE(dim < inShapeInfo[0], 0, "UNSTACK op: dimension should be lower than rank of input %i, but got dimension=%i !", inShapeInfo[0], dim);
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
auto dim = INT_ARG(0);
if (dim < 0)
dim += shape::rank(inShape);
if(ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) {
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower than rank of input %i, but got dimension=%i !", inShape[0], dim);
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
if(shape::shapeOf(inShapeInfo)[dim] == 0)
return SHAPELIST();
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
if(shape::shapeOf(inShape)[dim] == 0)
return SHAPELIST();
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
std::vector<Nd4jLong> outShape;
for(uint i = 0; i < shape::rank(inShape); ++i)
if(i != dim)
outShape.push_back(shape::shapeOf(inShape)[i]);
const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim];
std::vector<Nd4jLong> outShape;
for(uint i = 0; i < shape::rank(inShapeInfo); ++i)
if(i != dim)
outShape.push_back(shape::shapeOf(inShapeInfo)[i]);
auto result = SHAPELIST();
for(uint i = 0; i < numTads; ++i)
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
return result;
}
auto result = SHAPELIST();
for(uint i = 0; i < numTads; ++i)
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape));
std::vector<int> dims;
for (int e = 0; e < shape::rank(inShape); e++)
if (e != dim)
dims.emplace_back(e);
if (dims.size() == 0 && shape::rank(inShape) == 1) { // split vector into lengthOf scalars
//
auto result = SHAPELIST();
for (Nd4jLong e = 0; e < shape::length(inShape); e++)
result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)));
return result;
}
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(inShape, dims);
auto numTads = tadPack.numberOfTads();
std::vector<Nd4jLong> shape(shape::rank(tadPack.primaryShapeInfo()));
for (int e = 0; e < shape::rank(tadPack.primaryShapeInfo()); e++)
shape[e] = shape::shapeOf(tadPack.primaryShapeInfo())[e];
// remove leading and trailing 1
if (inShape[0] == 2 && shape.size() == 2) {
if (shape[0] == 1) {
shape.erase(shape.begin());
} else if (shape[1] == 1) {
shape.erase(shape.end());
}
}
auto result = SHAPELIST();
for (int e = 0; e < numTads; e++) {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape);
result->push_back(newShape);
}
return result;
}
DECLARE_TYPES(unstack) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
->setSameMode(true);
}
return result;
}
std::vector<int> dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim});
if (dims.size() == 0 && shape::rank(inShapeInfo) == 1) { // split vector into lengthOf scalars
auto result = SHAPELIST();
for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++)
result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShapeInfo)));
return result;
}
std::vector<Nd4jLong> subArrShape(shape::rank(inShapeInfo) - 1);
for(uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++)
if(i != dim)
subArrShape[j++] = shape::shapeOf(inShapeInfo)[i];
// remove leading and trailing 1
if (inShapeInfo[0] == 2 && subArrShape.size() == 2) {
if (subArrShape[0] == 1)
subArrShape.erase(subArrShape.begin());
else if (subArrShape[1] == 1)
subArrShape.erase(subArrShape.end() - 1); // drop the trailing unit dimension (end() itself must not be erased)
}
auto result = SHAPELIST();
for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) {
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), subArrShape);
result->push_back(newShape);
}
return result;
}
DECLARE_TYPES(unstack) {
getOpDescriptor()
->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
->setSameMode(true);
}
}
}
#endif
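A hypothetical usage sketch for the rewritten op, under the same assumptions as the stack example above: unstacking along iArg 0 produces sizeAt(dim) outputs whose shape is the input shape with that dimension removed (the rank-1 case degenerates to scalars, as handled explicitly in the shape function).
sd::ops::unstack op;
auto x = NDArrayFactory::create<float>('c', {3, 4});
auto results = op.evaluate({&x}, {}, {0});   // iArg 0: dimension to unstack along
// results->size() == 3 and every results->at(i) has shape {4}
delete results;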

View File

@ -54,15 +54,15 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
maxTimeStep = INPUT_VARIABLE(9);
break;
}
auto hFW = OUTPUT_VARIABLE(0); // cell outputs for forward RNN [time x bS x numUnitsFW] or [bS x time x numUnitsFW], shape depends on timeMajor parameter
auto hBW = OUTPUT_VARIABLE(1); // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x time x numUnitsBW], shape depends on timeMajor parameter
auto hFWFinal = OUTPUT_VARIABLE(2); // final cell out for forward RNN [bS x numUnitsFW]
auto hBWFinal = OUTPUT_VARIABLE(3); // final cell out for backward RNN [bS x numUnitsBW]
REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
const int inRank = x->rankOf();
const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1);
@ -70,16 +70,26 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
const int numUnitsFW = WxFW->sizeAt(1);
const int numUnitsBW = WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
std::vector<Nd4jLong> expectedbFWshape = {2*numUnitsFW};
std::vector<Nd4jLong> expectedbBWshape = {2*numUnitsBW};
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape) , 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW) {
std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
}
if(h0BW) {
std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
}
if(maxTimeStep) {
std::vector<Nd4jLong> expectedmaxTimeStepshape = {bS};
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
}
// forward steps
@ -101,19 +111,19 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
REQUIRE_TRUE (resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0);
// backward steps
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0); // [time x bS x numUnitsBW] or [bS x time x numUnitsBW]
hBWFinal->assign(resultsBW->at(1));
// reverse hBWtemp
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0));
delete resultsOut;
delete resultsBW;
delete resultsIn;
delete resultsFW;
if(seqLen != maxTimeStep)
delete seqLen;
@ -128,7 +138,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
}
DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter
auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
@ -143,7 +153,7 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...]
switch(block.width()) {
case 8:
maxTimeStep = INPUT_VARIABLE(7);
@ -160,8 +170,8 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
}
REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
const int inRank = x->rankOf();
const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1);
@ -169,16 +179,28 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
const int numUnitsFW = WxFW->sizeAt(1);
const int numUnitsBW = WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
std::vector<Nd4jLong> expectedbFWshape = {2*numUnitsFW};
std::vector<Nd4jLong> expectedbBWshape = {2*numUnitsBW};
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW) {
std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
}
if(h0BW) {
std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
}
if(maxTimeStep) {
std::vector<Nd4jLong> expectedmaxTimeStepshape = {bS};
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
}
// evaluate output shapeInfos
Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr);
@ -187,13 +209,13 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong);
ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong);
hFWShapeInfo[0] = hBWShapeInfo[0] = inRank;
hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS;
hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time;
hFWShapeInfo[3] = numUnitsFW;
hBWShapeInfo[3] = numUnitsBW;
hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1;
hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS;
hFWFinalPrevShapeInfo[2] = numUnitsFW;
hBWFinalPrevShapeInfo[2] = numUnitsBW;
@ -201,9 +223,9 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
ShapeUtils::updateStridesAndType(hBWShapeInfo, x->getShapeInfo(), x->ordering());
ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->getShapeInfo(), x->ordering());
ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->getShapeInfo(), x->ordering());
return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo));
}
}
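A condensed sketch of the shapeInfo construction pattern used in this and the following shape functions (variable names mirror the surrounding code; the strides, order and data type are filled in by updateStridesAndType, so only the rank and dimensions are written by hand):
Nd4jLong* hShapeInfo = nullptr;
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(3), Nd4jLong);
hShapeInfo[0] = 3;           // rank
hShapeInfo[1] = time;        // dimension 0
hShapeInfo[2] = bS;          // dimension 1
hShapeInfo[3] = numUnitsFW;  // dimension 2
ShapeUtils::updateStridesAndType(hShapeInfo, x->getShapeInfo(), x->ordering());
return SHAPELIST(CONSTANT(hShapeInfo));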

View File

@ -60,12 +60,18 @@ CUSTOM_OP_IMPL(dynamic_rnn, 4, 2, false, 0, 0) {
const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0);
const int numUnits = Wx->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(Wh) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(b) == ShapeUtils::shapeAsString({2*numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(b).c_str());
if(h0)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
std::vector<Nd4jLong> expectedBShape = {2*numUnits};
REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
if(h0) {
std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
}
if(maxTimeStep) {
std::vector<Nd4jLong> expectedmaxTimeStepShape = {bS};
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
}
if(timeMajor == false) {
x = new NDArray(x->permute({1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize]
@ -127,12 +133,19 @@ DECLARE_SHAPE_FN(dynamic_rnn) {
const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1];
const int numUnits = WxShapeInfo[2];
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhShapeInfo) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bShapeInfo) == ShapeUtils::shapeAsString({2*numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
if(h0ShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0ShapeInfo) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
if(maxTimeStepShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
std::vector<Nd4jLong> expectedBShape = {2*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
if(h0ShapeInfo) {
std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
}
if(maxTimeStepShapeInfo) {
std::vector<Nd4jLong> expectedmaxTimeStepShape = {bS};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
}
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr);

View File

@ -32,35 +32,31 @@ namespace ops {
CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
auto x = INPUT_VARIABLE(0); // input [time x bS x iS]
auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU]
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU]
auto b = INPUT_VARIABLE(4); // biases, [3*nU]
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step
const int rank = x->rankOf(); // = 3
const int time = x->sizeAt(0);
const int bS = x->sizeAt(1);
const int iS = x->sizeAt(2);
const int nU = h0->sizeAt(1);
const std::string h0Shape = ShapeUtils::shapeAsString(h0);
const std::string h0CorrectShape = ShapeUtils::shapeAsString({bS, nU});
const std::string wxShape = ShapeUtils::shapeAsString(Wx);
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
const std::string whShape = ShapeUtils::shapeAsString(Wh);
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
REQUIRE_TRUE(h0Shape == h0CorrectShape, 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", h0CorrectShape.c_str(), h0Shape.c_str());
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
const std::vector<Nd4jLong> h0CorrectShape = {bS, nU};
const std::vector<Nd4jLong> wxCorrectShape = {iS, 3*nU};
const std::vector<Nd4jLong> whCorrectShape = {nU, 3*nU};
const std::vector<Nd4jLong> bCorrectShape = {3*nU};
REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h);
return Status::OK();
}
@ -72,7 +68,7 @@ CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
}
DECLARE_SHAPE_FN(gru) {
const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0
const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits]
@ -85,34 +81,30 @@ DECLARE_SHAPE_FN(gru) {
const auto inSize = xShapeInfo[3];
const auto numUnits = h0ShapeInfo[2];
const std::string h0Shape = ShapeUtils::shapeAsString(h0ShapeInfo);
const std::string h0CorrectShape = ShapeUtils::shapeAsString({bS, numUnits});
const std::string wxShape = ShapeUtils::shapeAsString(WxShapeInfo);
const std::string wxCorrectShape = ShapeUtils::shapeAsString({inSize, 3*numUnits});
const std::string whShape = ShapeUtils::shapeAsString(WhShapeInfo);
const std::string whCorrectShape = ShapeUtils::shapeAsString({numUnits, 3*numUnits});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*numUnits});
REQUIRE_TRUE(h0Shape == h0CorrectShape, 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", h0CorrectShape.c_str(), h0Shape.c_str());
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
const std::vector<Nd4jLong> h0CorrectShape = {bS, numUnits};
const std::vector<Nd4jLong> wxCorrectShape = {inSize, 3*numUnits};
const std::vector<Nd4jLong> whCorrectShape = {numUnits, 3*numUnits};
const std::vector<Nd4jLong> bCorrectShape = {3*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
// evaluate output shapeInfo
Nd4jLong *hShapeInfo(nullptr);
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
hShapeInfo[0] = rank;
hShapeInfo[1] = time;
hShapeInfo[2] = bS;
hShapeInfo[3] = numUnits;
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo));
return SHAPELIST(hShapeInfo);
}
}
}

View File

@ -145,30 +145,21 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
const std::string hiShape = ShapeUtils::shapeAsString(hi);
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
const std::string wShape = ShapeUtils::shapeAsString(W);
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
const std::string wcShape = ShapeUtils::shapeAsString(Wc);
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
const std::string bcShape = ShapeUtils::shapeAsString(bc);
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdr);
const std::string dLduShape = ShapeUtils::shapeAsString(dLdu);
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdc);
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
const std::vector<Nd4jLong> hiCorrectShape = {bS, nU};
const std::vector<Nd4jLong> wCorrectShape = {iS+nU, 2*nU};
const std::vector<Nd4jLong> wcCorrectShape = {iS+nU, nU};
const std::vector<Nd4jLong> bCorrectShape = {2*nU};
const std::vector<Nd4jLong> bcCorrectShape = {nU};
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hi).c_str());
REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(W).c_str());
REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bc).c_str());
REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdr).c_str());
REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdu).c_str());
REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str());
REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
@ -210,30 +201,21 @@ DECLARE_SHAPE_FN(gruCell_bp) {
REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
const std::string wcShape = ShapeUtils::shapeAsString(wcShapeInfo);
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
const std::string bcShape = ShapeUtils::shapeAsString(bcShapeInfo);
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdrShapeInfo);
const std::string dLduShape = ShapeUtils::shapeAsString(dLduShapeInfo);
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdcShapeInfo);
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
const std::vector<Nd4jLong> hiCorrectShape = {bS, nU};
const std::vector<Nd4jLong> wCorrectShape = {iS+nU, 2*nU};
const std::vector<Nd4jLong> wcCorrectShape = {iS+nU, nU};
const std::vector<Nd4jLong> bCorrectShape = {2*nU};
const std::vector<Nd4jLong> bcCorrectShape = {nU};
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hiShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(wcShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bcShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdrShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLduShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdcShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdhShapeInfo).c_str());
Nd4jLong *dLdxShapeInfo = nullptr;
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
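The practical gain is on the happy path: comparing a shape against a vector of longs is a handful of integer comparisons, whereas the old idiom built std::string representations of every input's shape before any check ran, whether or not validation was going to fail. A toy, self-contained contrast of the two idioms (not libnd4j code) is sketched below.

// Toy contrast between the old string-based validation and the new integer comparison.
#include <cstdint>
#include <string>
#include <vector>

using Nd4jLong = int64_t;

// old style: stringify both shapes up front, then compare the strings (allocates every call)
static std::string shapeAsStringSketch(const std::vector<Nd4jLong>& s) {
    std::string out = "[";
    for (size_t i = 0; i < s.size(); ++i)
        out += std::to_string(s[i]) + (i + 1 < s.size() ? ", " : "");
    return out + "]";
}

static bool validateOldStyle(const std::vector<Nd4jLong>& got, const std::vector<Nd4jLong>& expected) {
    return shapeAsStringSketch(got) == shapeAsStringSketch(expected);
}

// new style: compare the integers; stringify only if an error message is actually needed
static bool validateNewStyle(const std::vector<Nd4jLong>& got, const std::vector<Nd4jLong>& expected) {
    return got == expected;
}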

@ -39,10 +39,10 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits]
auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj]
auto b = INPUT_VARIABLE(7); // biases, [4*numUnits]
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numProj], that is per each time step
auto c = OUTPUT_VARIABLE(1); // cell states [time x bS x numUnits] that is per each time step
const int peephole = INT_ARG(0); // if 1, provide peephole connections
const int projection = INT_ARG(1); // if 1, then projection is performed; if 0, then numProj==numUnits is mandatory!
@ -59,28 +59,21 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
const int numUnits = c0->sizeAt(1);
// input shapes validation
const std::string h0Shape = ShapeUtils::shapeAsString(h0);
const std::string correctH0Shape = ShapeUtils::shapeAsString({bS, numProj});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string correctC0Shape = ShapeUtils::shapeAsString({bS, numUnits});
const std::string WxShape = ShapeUtils::shapeAsString(Wx);
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
const std::string WhShape = ShapeUtils::shapeAsString(Wh);
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
const std::string WcShape = ShapeUtils::shapeAsString(Wc);
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
const std::string WpShape = ShapeUtils::shapeAsString(Wp);
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
const std::vector<Nd4jLong> correctH0Shape = {bS, numProj};
const std::vector<Nd4jLong> correctC0Shape = {bS, numUnits};
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(correctH0Shape == h0Shape, 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctH0Shape.c_str(), h0Shape.c_str());
REQUIRE_TRUE(correctC0Shape == c0Shape, 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctC0Shape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTM operation: projection option is switched off, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !");
helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, c, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias});
@ -95,7 +88,7 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
}
DECLARE_SHAPE_FN(lstm) {
auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
auto h0ShapeInfo = inputShape->at(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!!
@ -113,37 +106,30 @@ DECLARE_SHAPE_FN(lstm) {
const int inSize = xShapeInfo[3];
const int numProj = h0ShapeInfo[2];
const int numUnits = c0ShapeInfo[2];
// input shapes validation
const std::string h0Shape = ShapeUtils::shapeAsString(h0ShapeInfo);
const std::string correctH0Shape = ShapeUtils::shapeAsString({bS, numProj});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string correctC0Shape = ShapeUtils::shapeAsString({bS, numUnits});
const std::string WxShape = ShapeUtils::shapeAsString(WxShapeInfo);
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
const std::string WhShape = ShapeUtils::shapeAsString(WhShapeInfo);
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
const std::string WcShape = ShapeUtils::shapeAsString(WcShapeInfo);
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
const std::string WpShape = ShapeUtils::shapeAsString(WpShapeInfo);
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
const std::vector<Nd4jLong> correctH0Shape = {bS, numProj};
const std::vector<Nd4jLong> correctC0Shape = {bS, numUnits};
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(correctH0Shape == h0Shape, 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctH0Shape.c_str(), h0Shape.c_str());
REQUIRE_TRUE(correctC0Shape == c0Shape, 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctC0Shape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj]
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits]
hShapeInfo[0] = cShapeInfo[0] = rank;
hShapeInfo[1] = cShapeInfo[1] = time;
hShapeInfo[2] = cShapeInfo[2] = bS;
@ -152,9 +138,9 @@ DECLARE_SHAPE_FN(lstm) {
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo));
ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, shape::order(c0ShapeInfo));
return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo));
}
}
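The shape functions still build their outputs by hand: a raw shapeInfo buffer is allocated, rank and dimensions are written in, and ShapeUtils::updateStridesAndType fills in strides, order and data type from the reference input. The standalone sketch below shows roughly what a 'c'-ordered rank-3 buffer holds for the [time x bS x numProj] output; the trailing fields (type, ews, order) are an assumption here and are handled by the library helper in the real code.

// Standalone sketch (not libnd4j code): dimensions and 'c'-order strides for a
// rank-3 [time x bS x numProj] output, as updateStridesAndType would compute them.
#include <cstdint>
#include <vector>

using Nd4jLong = int64_t;

static std::vector<Nd4jLong> cOrderShapeInfoSketch(Nd4jLong time, Nd4jLong bS, Nd4jLong numProj) {
    std::vector<Nd4jLong> info;
    info.push_back(3);                 // rank
    info.push_back(time);              // dim 0
    info.push_back(bS);                // dim 1
    info.push_back(numProj);           // dim 2
    info.push_back(bS * numProj);      // stride 0 ('c' order: product of the trailing dims)
    info.push_back(numProj);           // stride 1
    info.push_back(1);                 // stride 2
    // data type, ews and order trailing fields are filled by the library helper in real code
    return info;
}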

@ -39,10 +39,10 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits]
auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj]
auto b = INPUT_VARIABLE(7); // biases, [4*numUnits]
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x numProj], that is at current time step t
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is at current time step t
const int peephole = INT_ARG(0); // if 1, provide peephole connections
const int projection = INT_ARG(1); // if 1, then projection is performed; if 0, then numProj==numUnits is mandatory!
@ -51,40 +51,33 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped
const double forgetBias = T_ARG(2);
const int rank = xt->rankOf();
const int bS = xt->sizeAt(0);
const int inSize = xt->sizeAt(1);
const int numProj = ht_1->sizeAt(1);
const int numUnits = ct_1->sizeAt(1);
// input shapes validation
const std::string ht_1Shape = ShapeUtils::shapeAsString(ht_1);
const std::string correctHt_1Shape = ShapeUtils::shapeAsString({bS, numProj});
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1);
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, numUnits});
const std::string WxShape = ShapeUtils::shapeAsString(Wx);
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
const std::string WhShape = ShapeUtils::shapeAsString(Wh);
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
const std::string WcShape = ShapeUtils::shapeAsString(Wc);
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
const std::string WpShape = ShapeUtils::shapeAsString(Wp);
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
REQUIRE_TRUE(correctHt_1Shape == ht_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctHt_1Shape.c_str(), ht_1Shape.c_str());
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// input shapes validation
const std::vector<Nd4jLong> correctHt_1Shape = {bS, numProj};
const std::vector<Nd4jLong> correctCt_1Shape = {bS, numUnits};
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1).c_str());
REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str());
REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTMCELL operation: projection option is switched off, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !");
// calculations
helpers::lstmCell(block.launchContext(), xt,ht_1,ct_1, Wx,Wh,Wc,Wp, b, ht,ct, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias});
return Status::OK();
}
@ -95,53 +88,46 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
}
DECLARE_SHAPE_FN(lstmCell) {
auto xtShapeInfo = inputShape->at(0); // input [bS x inSize]
auto ht_1ShapeInfo = inputShape->at(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!!
auto ct_1ShapeInfo = inputShape->at(2); // previous cell state [bS x numUnits], that is at previous time step t-1
auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits]
auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits]
auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits]
auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj]
auto bShapeInfo = inputShape->at(7); // biases, [4*numUnits]
const int rank = shape::rank(xtShapeInfo);
const auto bS = xtShapeInfo[1];
const auto inSize = xtShapeInfo[2];
const auto numProj = ht_1ShapeInfo[2];
const auto numUnits = ct_1ShapeInfo[2];
// input shapes validation
const std::string ht_1Shape = ShapeUtils::shapeAsString(ht_1ShapeInfo);
const std::string correctHt_1Shape = ShapeUtils::shapeAsString({bS, numProj});
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1ShapeInfo);
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, numUnits});
const std::string WxShape = ShapeUtils::shapeAsString(WxShapeInfo);
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
const std::string WhShape = ShapeUtils::shapeAsString(WhShapeInfo);
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
const std::string WcShape = ShapeUtils::shapeAsString(WcShapeInfo );
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
const std::string WpShape = ShapeUtils::shapeAsString(WpShapeInfo);
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
REQUIRE_TRUE(correctHt_1Shape == ht_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctHt_1Shape.c_str(), ht_1Shape.c_str());
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// input shapes validation
const std::vector<Nd4jLong> correctHt_1Shape = {bS, numProj};
const std::vector<Nd4jLong> correctCt_1Shape = {bS, numUnits};
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj]
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits]
hShapeInfo[0] = cShapeInfo[0] = rank;
hShapeInfo[1] = cShapeInfo[1] = bS;
hShapeInfo[2] = numProj;
@ -154,7 +140,7 @@ DECLARE_SHAPE_FN(lstmCell) {
RELEASE(hShapeInfo, block.workspace());
RELEASE(cShapeInfo, block.workspace());
return result;
}
}
}
}
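One readability note: the projection check kept in both lstm and lstmCell, REQUIRE_TRUE(!(!projection && numUnits != numProj), ...), is one De Morgan step away from the plainer projection || numUnits == numProj. The small standalone snippet below verifies the equivalence over all relevant cases.

// Standalone check that the double-negated condition equals its simplified form.
#include <cassert>

int main() {
    for (int projection = 0; projection <= 1; ++projection)
        for (int numUnits = 1; numUnits <= 2; ++numUnits)
            for (int numProj = 1; numProj <= 2; ++numProj) {
                const bool original   = !(!projection && numUnits != numProj);
                const bool simplified = projection || numUnits == numProj;
                assert(original == simplified);
            }
    return 0;
}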

@ -54,20 +54,15 @@ CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) {
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
const std::vector<Nd4jLong> wCorrectShape = {3*inSize, inSize};
const std::vector<Nd4jLong> bCorrectShape = {2*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
if(mask)
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
// xm = x * mask
auto xm = x;
@ -111,20 +106,15 @@ DECLARE_SHAPE_FN(sru) {
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
const std::vector<Nd4jLong> wCorrectShape = {3*inSize, inSize};
const std::vector<Nd4jLong> bCorrectShape = {2*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
if(maskShapeInfo)
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
Nd4jLong* newShapeInfo1 = nullptr;
ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time]
@ -350,20 +340,15 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
if(mask)
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
@ -397,20 +382,16 @@ DECLARE_SHAPE_FN(sru_bi) {
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
if(maskShapeInfo)
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
char order = shape::order(xShapeInfo);
@ -453,23 +434,17 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
if(mask)
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ct);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
const std::vector<Nd4jLong> ctCorrectShape = {time, bS, 2*inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
if(mask) {
const std::string maskShape = ShapeUtils::shapeAsString(mask);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ct).c_str());
if(mask)
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
@ -507,29 +482,21 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
if(maskShapeInfo)
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
const std::string inGradC0Shape = ShapeUtils::shapeAsString(inGradC0ShapeInfo);
const std::string inGradC0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
const std::string inGradHtShape = ShapeUtils::shapeAsString(inGradHtShapeInfo);
const std::string inGradHtCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
const std::vector<Nd4jLong> ctCorrectShape = {time, bS, 2*inSize};
const std::vector<Nd4jLong> inGradC0CorrectShape = {bS, 2*inSize};
const std::vector<Nd4jLong> inGradHtCorrectShape = {time, bS, 2*inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
REQUIRE_TRUE(inGradC0Shape == inGradC0CorrectShape, 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", inGradC0CorrectShape.c_str(), inGradC0Shape.c_str());
REQUIRE_TRUE(inGradHtShape == inGradHtCorrectShape, 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", inGradHtCorrectShape.c_str(), inGradHtShape.c_str());
if(maskShapeInfo) {
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
}
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ctShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str());
if(maskShapeInfo)
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
const char order = shape::order(xShapeInfo);
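In the sru ops the mask input is optional, so its shape is validated only when a mask is actually supplied; the refactor keeps that null check but drops the intermediate string, comparing the mask directly against the same correct shape as the initial state ({bS, inSize} for sru, {bS, 2*inSize} for sru_bi). A minimal standalone sketch of that optional-input pattern, with hypothetical names, follows.

// Minimal sketch of optional-input validation: a null pointer means "not supplied",
// and the shape check is simply skipped in that case. Names are hypothetical.
#include <cstdint>
#include <stdexcept>
#include <vector>

using Nd4jLong = int64_t;

static void validateOptionalShape(const std::vector<Nd4jLong>* maybeShape,
                                  const std::vector<Nd4jLong>& required) {
    if (maybeShape == nullptr)
        return;                                   // optional input absent -> nothing to check
    if (*maybeShape != required)
        throw std::invalid_argument("wrong shape of mask array");
}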

@ -31,7 +31,7 @@ namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
auto b = INPUT_VARIABLE(3); // biases [2*inSize]
@ -40,25 +40,22 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
const int rank = xt->rankOf();
const int bS = xt->sizeAt(0);
const int inSize = xt->sizeAt(1); // inSize - number of features
// input shapes validation
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1);
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, inSize});
const std::string WShape = ShapeUtils::shapeAsString(w);
const std::string correctWShape = ShapeUtils::shapeAsString({inSize, 3*inSize});
const std::string bShape = ShapeUtils::shapeAsString(b);
const std::string correctBShape = ShapeUtils::shapeAsString({2*inSize});
const std::vector<Nd4jLong> correctCt_1Shape = {bS, inSize};
const std::vector<Nd4jLong> correctWShape = {inSize, 3*inSize};
const std::vector<Nd4jLong> correctBShape = {2*inSize};
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
REQUIRE_TRUE(correctWShape == WShape, 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", correctWShape.c_str(), WShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str());
REQUIRE_TRUE(w->isSameShape(correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
// fixme: shitty initializer lists
helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct);
return Status::OK();
}
@ -71,40 +68,37 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
DECLARE_SHAPE_FN(sruCell) {
auto xtShapeInfo = inputShape->at(0); // input [bS x inSize], bS - batch size, inSize - number of features
auto ct_1ShapeInfo = inputShape->at(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize]
auto bShapeInfo = inputShape->at(3); // biases [2*inSize]
const int rank = xtShapeInfo[0];
const int bS = xtShapeInfo[1];
const int inSize = xtShapeInfo[2]; // inSize - number of features
// input shapes validation
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1ShapeInfo);
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, inSize});
const std::string WShape = ShapeUtils::shapeAsString(wShapeInfo);
const std::string correctWShape = ShapeUtils::shapeAsString({inSize, 3*inSize});
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
const std::string correctBShape = ShapeUtils::shapeAsString({2*inSize});
const std::vector<Nd4jLong> correctCt_1Shape = {bS, inSize};
const std::vector<Nd4jLong> correctWShape = {inSize, 3*inSize};
const std::vector<Nd4jLong> correctBShape = {2*inSize};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape) , 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo ,correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo ,correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
REQUIRE_TRUE(correctWShape == WShape, 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", correctWShape.c_str(), WShape.c_str());
REQUIRE_TRUE(correctBShape == bShape, 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj]
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits]
hShapeInfo[0] = cShapeInfo[0] = rank;
hShapeInfo[1] = cShapeInfo[1] = bS;
hShapeInfo[2] = cShapeInfo[2] = inSize;
ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo));
ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo));
return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting(hShapeInfo, block.workspace()), ConstantShapeHelper::getInstance()->createFromExisting(cShapeInfo, block.workspace()));
}
}
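
The refactoring above replaces string-based shape validation with direct vector comparison, so shapes are only converted to strings when a check actually fails and a message has to be formatted. A minimal standalone sketch of that idea, using plain std::vector instead of the real ShapeUtils/NDArray API (all names here are illustrative only, not libnd4j symbols):

    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    using Shape = std::vector<int64_t>;

    // Format a shape only when a message is actually needed (i.e. on failure).
    static std::string shapeAsString(const Shape& s) {
        std::ostringstream os;
        os << '[';
        for (size_t i = 0; i < s.size(); ++i)
            os << s[i] << (i + 1 < s.size() ? ", " : "");
        os << ']';
        return os.str();
    }

    // Cheap check: element-wise comparison, no string allocation on the hot path.
    static bool validateShape(const Shape& actual, const Shape& expected, const char* what) {
        if (actual == expected)
            return true;
        std::cerr << "wrong shape of " << what << ", expected is "
                  << shapeAsString(expected) << ", but got "
                  << shapeAsString(actual) << " instead !\n";
        return false;
    }

    int main() {
        const int64_t bS = 2, inSize = 3;
        Shape ct_1 = {bS, inSize};
        validateShape(ct_1, {bS, inSize}, "previous cell state");   // passes silently
        validateShape(ct_1, {bS, 4 * inSize}, "weights");           // prints a message
    }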

View File

@ -55,14 +55,14 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
maxTimeStep = INPUT_VARIABLE(9);
break;
}
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + numUnitsBW)], that is per each time step
auto hFWFinal = OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW]
auto hBWFinal = OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBW]
REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
const Nd4jLong inRank = x->rankOf();
const Nd4jLong time = x->sizeAt(0);
@ -70,39 +70,48 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
const Nd4jLong numUnitsFW = WxFW->sizeAt(1);
const Nd4jLong numUnitsBW = WxBW->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
if(h0BW)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
const std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
const std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
const std::vector<Nd4jLong> expectedbFWshape = {2 * numUnitsFW};
const std::vector<Nd4jLong> expectedbBWshape = {2 * numUnitsBW};
// forward steps
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
if(h0FW) {
const std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
}
if(h0BW) {
const std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
}
if(maxTimeStep)
REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
// forward steps
auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), block.launchContext());
helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, maxTimeStep, hFW, hFWFinal);
auto seqLen = maxTimeStep;
if(seqLen == nullptr) {
// seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), block.launchContext()); // [bS]
seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, block.launchContext()); // [bS]
*seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time
}
}
// reverse x
auto revOut = new NDArray(x, false, block.launchContext());
helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1);
// backward steps
auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), block.launchContext());
helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, maxTimeStep, hBW, hBWFinal);
// reverse hBW
auto hBWcopy = new NDArray(*hBW);
helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1);
// concatenate hFW and hBW along the last (third) dimension
@ -117,7 +126,7 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
if(seqLen != maxTimeStep)
delete seqLen;
return Status::OK();
}
@ -128,7 +137,7 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
}
DECLARE_SHAPE_FN(static_bidirectional_rnn) {
auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
auto WxFWShapeInfo = inputShape->at(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
@ -167,16 +176,25 @@ DECLARE_SHAPE_FN(static_bidirectional_rnn) {
const int numUnitsFW = WxFWShapeInfo[2];
const int numUnitsBW = WxBWShapeInfo[2];
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFWShapeInfo) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBWShapeInfo) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFWShapeInfo) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBWShapeInfo) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str());
if(h0FWShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FWShapeInfo) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str());
if(h0BWShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BWShapeInfo) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str());
const std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
const std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
const std::vector<Nd4jLong> expectedbFWshape = {2 * numUnitsFW};
const std::vector<Nd4jLong> expectedbBWshape = {2 * numUnitsBW};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str());
if(h0FWShapeInfo) {
const std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str());
}
if(h0BWShapeInfo) {
const std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str());
}
if(maxTimeStepShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr);
@ -195,9 +213,9 @@ DECLARE_SHAPE_FN(static_bidirectional_rnn) {
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo));
ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo));
ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo));
return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo));
}
}
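
The op body above produces the backward direction by reversing the input along the time axis, running the ordinary forward time loop, and reversing the result back. A tiny self-contained sketch of that reverse / run / reverse pattern on plain vectors (a toy accumulator stands in for the real rnnTimeLoop helper; nothing here is the actual libnd4j API):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Toy stand-in for a forward time loop: running sums over time steps.
    static std::vector<float> runForward(const std::vector<float>& x) {
        std::vector<float> h(x.size());
        float state = 0.f;
        for (size_t t = 0; t < x.size(); ++t) {
            state += x[t];          // "cell update"
            h[t] = state;           // per-step output
        }
        return h;
    }

    int main() {
        std::vector<float> x = {1.f, 2.f, 3.f, 4.f};     // [time]
        // forward direction
        std::vector<float> hFW = runForward(x);
        // backward direction: reverse input, run the same loop, reverse its output
        std::vector<float> xRev(x.rbegin(), x.rend());
        std::vector<float> hBW = runForward(xRev);
        std::reverse(hBW.begin(), hBW.end());
        for (size_t t = 0; t < x.size(); ++t)
            std::cout << "t=" << t << "  hFW=" << hFW[t] << "  hBW=" << hBW[t] << '\n';
    }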

View File

@ -58,12 +58,17 @@ CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) {
const int inSize = x->sizeAt(2);
const int numUnits = Wx->sizeAt(1);
REQUIRE_TRUE(ShapeUtils::shapeAsString(Wh) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(b) == ShapeUtils::shapeAsString({2*numUnits}), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(b).c_str());
if(h0)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0).c_str());
const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
const std::vector<Nd4jLong> expectedbShape = {2 * numUnits};
REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
if(h0) {
const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
}
if(maxTimeStep)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal);
@ -107,12 +112,17 @@ DECLARE_SHAPE_FN(static_rnn) {
const int bS = xShapeInfo[2];
const int numUnits = WxShapeInfo[2];
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhShapeInfo) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::shapeAsString(bShapeInfo) == ShapeUtils::shapeAsString({2*numUnits}), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
if(h0ShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0ShapeInfo) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
const std::vector<Nd4jLong> expectedbShape = {2 * numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
if(h0ShapeInfo){
const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
}
if(maxTimeStepShapeInfo)
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
// evaluate output shapeInfos
Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr);

View File

@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
// first of all take into account possible presence of empty arrays
// also if scalar is present -> copy its value to vector with length=1
std::vector<NDArray*> nonEmptyArrs;
std::vector<const NDArray*> nonEmptyArrs;
std::vector<int> arrsToDelete;
int index = 0;
bool allOfSameType = true;

View File

@ -36,16 +36,16 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
auto paddings = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
const int rank = input->rankOf();
// input validation
std::string expectedPaddingsShape = ShapeUtils::shapeAsString({rank, 2});
std::string currentPaddingsShape = ShapeUtils::shapeAsString(paddings);
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", expectedPaddingsShape.c_str(), currentPaddingsShape.c_str());
std::vector<Nd4jLong> expectedPaddingsShape = {rank, 2};
std::vector<Nd4jLong> currentPaddingsShape = paddings->getShapeAsVector();
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str());
NDArray padValue(input->dataType(), block.launchContext());
// in case of REFLECT and SYMMETRIC modes paddings must obey additional shape requirements
if (INT_ARG(0) == 0) { // CONSTANT mode
if(block.width() > 2) {
REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, "PAD op: data types of input and padValue arrays should be the same but got %i and %i correspondingly !", input->dataType(), INPUT_VARIABLE(2)->dataType());
@ -68,10 +68,10 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
// std::vector<int> dimensions(input->rankOf());
// std::iota(dimensions.begin(), dimensions.end(), 0); // fill with 0, 1, ... rank-1
// helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, dimensions, 0, 0, 0, padValue);
helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, padValue);
return Status::OK();
}
@ -85,27 +85,26 @@ DECLARE_TYPES(pad) {
DECLARE_SHAPE_FN(pad) {
// check shape of paddings
auto inputShapeInfo = inputShape->at(0);
auto paddings = INPUT_VARIABLE(1);
const int rank = inputShapeInfo[0];
// paddings validation
std::string expectedPaddingsShape = ShapeUtils::shapeAsString({rank, 2});
std::string currentPaddingsShape = ShapeUtils::shapeAsString(paddings);
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", expectedPaddingsShape.c_str(), currentPaddingsShape.c_str());
const std::vector<Nd4jLong> expectedPaddingsShape = {rank, 2};
const std::vector<Nd4jLong> currentPaddingsShape = paddings->getShapeAsVector();
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str());
Nd4jLong* outShapeInfo = nullptr;
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
outShapeInfo[0] = rank;
for(int i=1; i <= rank; ++i)
outShapeInfo[i] = inputShapeInfo[i] + paddings->e<Nd4jLong>(i-1,0) + paddings->e<Nd4jLong>(i-1,1);
ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo));
ShapeDescriptor descriptor(outShapeInfo);
RELEASE(outShapeInfo, block.getWorkspace());
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
}
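
The shape function above computes each output dimension as the input dimension plus the leading and trailing pad counts taken from the corresponding row of the {rank, 2} paddings array. A small standalone sketch of that arithmetic, with plain vectors in place of shapeInfo buffers (names and values are illustrative):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        std::vector<int64_t> inShape = {2, 3};                       // rank-2 input
        std::vector<std::vector<int64_t>> paddings = {{1, 1},        // dim 0: pad 1 before, 1 after
                                                      {0, 2}};       // dim 1: pad 0 before, 2 after
        std::vector<int64_t> outShape(inShape.size());
        for (size_t i = 0; i < inShape.size(); ++i)
            outShape[i] = inShape[i] + paddings[i][0] + paddings[i][1];
        // expected output shape: {4, 5}
        for (auto d : outShape) std::cout << d << ' ';
        std::cout << '\n';
    }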

View File

@ -27,15 +27,15 @@ namespace sd {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void concat_(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
static void concat_(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
sd::SpecialMethods<T>::concatCpuGeneric(inArrs, output, axis);
}
void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
BUILD_SINGLE_SELECTOR(output.dataType(), concat_,(inArrs, output, axis), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector<NDArray*>& inArrs, NDArray& output, const int axis), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis), LIBND4J_TYPES);
}
}
}

View File

@ -15,13 +15,14 @@
******************************************************************************/
//
// Created by Yurii Shyrma on 02.01.2018
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/helpers/stack.h>
#include <helpers/ShapeUtils.h>
#include <array/ResultSet.h>
#include <execution/Threads.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
@ -31,37 +32,90 @@ namespace helpers {
///////////////////////////////////////////////////////////////////
template <typename T>
static void stack_(const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
const int numOfSubArrs = inArrs.size();
if(inArrs[0]->rankOf() == 0) {
int inSize = inArrs.size();
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++)
outArr->p<T>(i, inArrs[i]->t<T>(0));
output.p<T>(i, inArrs[i]->t<T>(0));
};
samediff::Threads::parallel_for(func, 0, inSize);
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
}
else {
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
auto list = outArr->allTensorsAlongDimension(dimsToExclude); // list.size() == block.width()
int listSize = list.size();
auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim}));
Nd4jLong* zTadShapeInfo = zTadPack.primaryShapeInfo();
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]);
NativeOpExecutioner::execTransformAny(inArrs[0]->getContext(), transform::Assign,
inArrs[i]->getBuffer(), inArrs[i]->getShapeInfo(), nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
zBuff, zTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
};
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
////////////////////////////////////////////////////////////////////////
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
template <typename T>
static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
const int numOfSubArrs = outArrs.size();
if(outArrs[0]->rankOf() == 0) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++)
list.at(i)->assign(inArrs[i]);
outArrs[i]->p<T>(0, input.t<T>(i));
};
samediff::Threads::parallel_tad(func, 0, listSize);
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
}
else {
auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim}));
Nd4jLong* xTadShapeInfo = xTadPack.primaryShapeInfo();
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
void* xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]);
NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign,
xBuff, xTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
outArrs[i]->getBuffer(), outArrs[i]->getShapeInfo(), nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
};
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
}
}
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (inArrs, outArr, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void unstack_, (const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim), LIBND4J_TYPES);
}
}
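
Both helpers above copy each input (or output) sub-array into or out of a tensor-along-dimension (TAD) of the stacked array, using the Assign transform rather than per-element loops. A simplified host-only sketch of the stack/unstack idea for the easiest case, dim = 0 with C-ordered data, where every TAD is a contiguous slice at offset i * subArrLen (the real helper computes offsets for an arbitrary dim via ConstantTadHelper; this is not that code):

    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
        const size_t numArrs = 3, subArrLen = 4;
        std::vector<std::vector<float>> ins = {{0, 1, 2, 3}, {10, 11, 12, 13}, {20, 21, 22, 23}};

        // stack along dim 0: each input fills one contiguous TAD of the output
        std::vector<float> out(numArrs * subArrLen);
        for (size_t i = 0; i < numArrs; ++i)
            std::memcpy(out.data() + i * subArrLen, ins[i].data(), subArrLen * sizeof(float));

        // unstack is the mirror operation: copy each TAD back out
        std::vector<std::vector<float>> outs(numArrs, std::vector<float>(subArrLen));
        for (size_t i = 0; i < numArrs; ++i)
            std::memcpy(outs[i].data(), out.data() + i * subArrLen, subArrLen * sizeof(float));

        std::cout << out[5] << ' ' << outs[1][1] << '\n';   // both print 11
    }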

View File

@ -83,14 +83,12 @@ __host__ static void concatCudaLauncher(const int blocksPerGrid, const int threa
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
const int numOfInArrs = inArrs.size();
const auto sizeofT = output.sizeOfT();
for(int i = 0; i < numOfInArrs; ++i)
inArrs[i]->syncToDevice();
output.syncToDevice();
NDArray::prepareSpecialUse({&output}, inArrs);
bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1;
@ -122,43 +120,48 @@ void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, ND
return;
}
const bool isZcontin = output.strideAt(axis) == 1;
bool areInputsContin = true;
bool allSameOrder = true;
// const bool isZcontin = output.strideAt(axis) == 1;
// bool areInputsContin = true;
// bool allSameOrder = true;
// std::vector<Nd4jLong> strideOfContigStride(numOfInArrs);
if(isZcontin) {
for (uint i = 0; i < inArrs.size(); ++i) {
areInputsContin &= inArrs[i]->strideAt(axis) == 1;
allSameOrder &= output.ordering() == inArrs[i]->ordering();
if(!areInputsContin || !allSameOrder)
break;
}
}
// if(isZcontin) {
const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
// for (uint i = 0; i < inArrs.size(); ++i) {
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and for the output array
// areInputsContin &= inArrs[i]->strideAt(axis) == 1;
// allSameOrder &= output.ordering() == inArrs[i]->ordering();
// if(!areInputsContin || !allSameOrder)
// break;
const uint zDim = output.sizeAt(axis);
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
// }
// }
for (uint i = 0; i < output.lengthOf() / zDim; ++i) {
// const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
const auto iShift = i * sizeofT;
void* z = static_cast<int8_t*>(output.getSpecialBuffer()) + zDim * iShift;
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and for the output array
for (uint j = 0; j < numOfInArrs; ++j) {
const auto xDim = inArrs[j]->sizeAt(axis);
void* x = static_cast<int8_t*>(inArrs[j]->getSpecialBuffer()) + xDim * iShift;
const auto memSizeToCopy = xDim * sizeofT;
cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
z = static_cast<int8_t*>(z) + memSizeToCopy;
}
}
// const auto zStep = shape::strideOverContigAxis(axis, output.getShapeInfo());
if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
throw std::runtime_error("concat cuda: luckCase2 failed!");
}
else { // general (slower) case
// for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
// const auto iShift = i * sizeofT;
// void* z = static_cast<int8_t*>(output.getSpecialBuffer()) + zStep * iShift;
// for (uint j = 0; j < numOfInArrs; ++j) {
// const auto xDim = inArrs[j]->sizeAt(axis);
// void* x = static_cast<int8_t*>(inArrs[j]->getSpecialBuffer()) + strideOfContigStride[j] * iShift;
// const auto memSizeToCopy = xDim * sizeofT;
// cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
// z = static_cast<int8_t*>(z) + memSizeToCopy;
// }
// }
// if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
// throw std::runtime_error("concat cuda: luckCase2 failed!");
// }
// else { // general (slower) case
const int threadsPerBlock = 256;
const int blocksPerGrid = 512;
@ -181,11 +184,9 @@ void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, ND
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES);
manager.synchronize();
}
// }
for(int i = 0; i < numOfInArrs; ++i)
inArrs[i]->tickReadDevice();
output.tickWriteDevice();
NDArray::registerSpecialUse({&output}, inArrs);
}
}
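
The fast path referenced above (partly disabled in this revision) requires the concatenation axis to have stride 1 in the output and in every input, so each outer iteration can be served by one contiguous copy per input. A host-side sketch of the same layout argument for a trivial 2-D case, with plain memcpy standing in for cudaMemcpyAsync (shapes and names are illustrative, not the device code):

    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
        // two C-ordered inputs, shapes {2,1} and {2,5}; concat along axis 1 -> {2,6}
        std::vector<float> a = {1, 2};                                   // {2,1}
        std::vector<float> b = {10, 11, 12, 13, 14, 20, 21, 22, 23, 24}; // {2,5}
        const size_t rows = 2, aCols = 1, bCols = 5, zCols = aCols + bCols;

        std::vector<float> z(rows * zCols);
        for (size_t i = 0; i < rows; ++i) {
            float* dst = z.data() + i * zCols;
            // the concat axis has stride 1 in every array, so each slice is one contiguous run
            std::memcpy(dst, a.data() + i * aCols, aCols * sizeof(float));
            std::memcpy(dst + aCols, b.data() + i * bCols, bCols * sizeof(float));
        }
        for (float v : z) std::cout << v << ' ';
        std::cout << '\n';
    }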

View File

@ -48,11 +48,11 @@ __global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, voi
xLen = shape::length(xShapeInfo);
xRank = shape::rank(xShapeInfo);
zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays
totalThreads = gridDim.z * blockDim.z;
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
const auto tid = blockIdx.z * blockDim.z + threadIdx.z;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong coords[MAX_RANK];
@ -121,43 +121,48 @@ void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray
return;
}
const bool isXcontin = input.strideAt(axis) == 1;
bool areOutputsContin = true;
bool allSameOrder = true;
// const bool isXcontin = input.strideAt(axis) == 1;
// bool areOutputsContin = true;
// bool allSameOrder = true;
// std::vector<Nd4jLong> strideOfContigStride(outArrs.size());
if(isXcontin) {
for (uint i = 0; i < outArrs.size(); ++i) {
areOutputsContin &= outArrs[i]->strideAt(axis) == 1;
allSameOrder &= input.ordering() == outArrs[i]->ordering();
if(!areOutputsContin || !allSameOrder)
break;
}
}
// if(isXcontin) {
const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder;
// for (uint i = 0; i < outArrs.size(); ++i) {
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all output arrays and for the input array
// areOutputsContin &= outArrs[i]->strideAt(axis) == 1;
// allSameOrder &= input.ordering() == outArrs[i]->ordering();
// if(!areOutputsContin || !allSameOrder)
// break;
const auto xDim = input.sizeAt(axis);
const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->getShapeInfo());
// }
// }
for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
// const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder;
const auto iShift = i * sizeofT;
void* x = static_cast<int8_t*>(input.getSpecialBuffer()) + xDim * iShift;
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all output arrays and for the input array
for (uint j = 0; j < numOfSubArrs; ++j) {
void* z = static_cast<int8_t*>(outArrs[j]->getSpecialBuffer()) + zDim * iShift;
const auto memSizeToCopy = zDim * sizeofT;
cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
x = static_cast<int8_t*>(x) + memSizeToCopy;
}
}
// const auto xStep = shape::strideOverContigAxis(axis, input.getShapeInfo());
// const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs
if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
throw std::runtime_error("split cuda: luckCase2 failed!");
}
else { // general (slower) case
// for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) {
// const auto iShift = i * sizeofT;
// void* x = static_cast<int8_t*>(input.getSpecialBuffer()) + xStep * iShift;
// for (uint j = 0; j < numOfSubArrs; ++j) {
// void* z = static_cast<int8_t*>(outArrs[j]->getSpecialBuffer()) + strideOfContigStride[j] * iShift;
// const auto memSizeToCopy = zDim * sizeofT;
// cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
// x = static_cast<int8_t*>(x) + memSizeToCopy;
// }
// }
// if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
// throw std::runtime_error("split cuda: luckCase2 failed!");
// }
// else { // general (slower) case
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@ -175,7 +180,7 @@ void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray
BUILD_SINGLE_SELECTOR(input.dataType(), splitCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES);
manager.synchronize();
}
// }
for(int i = 0; i < numOfSubArrs; ++i)
outArrs[i]->tickWriteDevice();
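
split is the mirror image of concat: when the split axis has stride 1 everywhere, each outer iteration reads one contiguous run from the input and hands one contiguous chunk to every output. A host-side sketch under the same simplifying assumptions as the concat example above (memcpy standing in for cudaMemcpyAsync; not the actual device implementation):

    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
        // C-ordered input {2,6} split along axis 1 into {2,2} and {2,4}
        std::vector<float> x = {0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15};
        const size_t rows = 2, xCols = 6, aCols = 2, bCols = 4;

        std::vector<float> a(rows * aCols), b(rows * bCols);
        for (size_t i = 0; i < rows; ++i) {
            const float* src = x.data() + i * xCols;
            std::memcpy(a.data() + i * aCols, src, aCols * sizeof(float));          // first chunk
            std::memcpy(b.data() + i * bCols, src + aCols, bCols * sizeof(float));  // remainder
        }
        std::cout << a[3] << ' ' << b[4] << '\n';   // prints 11 and 12
    }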

View File

@ -31,79 +31,333 @@ namespace ops {
namespace helpers {
template <typename T>
static __global__ void stackKernel(void** inputList, void** inputShapeList, int inputListLength, Nd4jLong arrLen, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* tadShape, Nd4jLong *tadOffsets) {
///////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void stackScalarsCuda(void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
T* z = reinterpret_cast<T*>(vz);
if(tadShape == nullptr) { // scalar case
__shared__ Nd4jLong zLen, totalThreads;
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < inputListLength; i += gridDim.x * blockDim.x)
z[shape::getIndexOffset(i, zShapeInfo)] = reinterpret_cast<T*>(inputList[i])[0];
}
else {
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
for (int t = blockIdx.x; t < inputListLength; t += gridDim.x) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
auto tZ = z + tadOffsets[t];
auto tX = reinterpret_cast<T*>(inputList[t]);
auto xShapeInfo = reinterpret_cast<Nd4jLong*>(inputShapeList[t]);
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
for (int e = threadIdx.x; e < arrLen; e += blockDim.x)
tZ[shape::getIndexOffset(e, tadShape)] = tX[shape::getIndexOffset(e, xShapeInfo)];
}
}
}
const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[i]);
z[shape::getIndexOffset(i, zShapeInfo)] = *x;
}
}
///////////////////////////////////////////////////////////////////
template <typename T>
static void stack_(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
const bool scalarCase = inArrs[0]->isScalar();
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
stackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, vz, zShapeInfo);
}
NDArray::prepareSpecialUse({outArr}, {});
///////////////////////////////////////////////////////////////////
template <typename T>
static void stack_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
// FIXME: !!!
for (auto v:inArrs)
NDArray::prepareSpecialUse({}, {v});
const int numOfSubArrs = inArrs.size();
std::vector<void const*> inputList(inArrs.size());
std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
NDArray::prepareSpecialUse({&output}, inArrs);
for (size_t i = 0; i < inputList.size(); ++i) {
inputList[i] = inArrs[i]->getSpecialBuffer();
inputShapeList[i] = inArrs[i]->getSpecialShapeInfo();
}
if(inArrs[0]->rankOf() == 0) {
PointersManager manager(context, "helpers::stack");
auto dInBuffers = (void **) manager.replicatePointer(inputList.data(), inputList.size() * sizeof(Nd4jLong*));
auto dInShapeInfo = (void **) manager.replicatePointer(inputShapeList.data(), inputShapeList.size() * sizeof(Nd4jLong*));
std::vector<void*> hInBuffers(numOfSubArrs);
for(int i = 0; i < numOfSubArrs; ++i)
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
PointersManager manager(context, "helpers::stack cuda");
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
stackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, output.specialBuffer(), output.specialShapeInfo());
if(scalarCase) {
stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), outArr->getSpecialShapeInfo(), nullptr, nullptr);
}
else {
std::vector<int> axis = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(outArr->getShapeInfo(), axis);
stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), nullptr, packZ.specialShapeInfo(), packZ.specialOffsets());
}
manager.synchronize();
}
else {
NDArray::registerSpecialUse({outArr}, {});
auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim}));
Nd4jLong* zTadShapeInfo = zTadPack.primaryShapeInfo();
// FIXME: !!!
for (auto v:inArrs)
NDArray::registerSpecialUse({}, {v});
}
for (uint i = 0; i < numOfSubArrs; ++i) {
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (context, inArrs, outArr, dim), LIBND4J_TYPES);
}
void* zBuff = output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]);
BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);
NativeOpExecutioner::execTransformAny(context, transform::Assign,
nullptr, inArrs[i]->getShapeInfo(), inArrs[i]->getSpecialBuffer(), inArrs[i]->getSpecialShapeInfo(),
nullptr, zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(),
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
}
NDArray::registerSpecialUse({&output}, inArrs);
}
////////////////////////////////////////////////////////////////////////
void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (context, inArrs, output, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void unstackScalarsCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
const T* x = reinterpret_cast<const T*>(vx);
__shared__ Nd4jLong xLen, totalThreads;
if (threadIdx.x == 0) {
xLen = shape::length(xShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
T* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[i]);
*z = x[shape::getIndexOffset(i, xShapeInfo)];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
unstackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz);
}
///////////////////////////////////////////////////////////////////
template <typename T>
static void unstack_(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
const int numOfSubArrs = outArrs.size();
// NDArray::prepareSpecialUse(outArrs, {&input});
input.syncToDevice();
for (const auto a : outArrs)
a->getDataBuffer()->allocateSpecial();
if(outArrs[0]->rankOf() == 0) {
std::vector<void*> hOutBuffers(numOfSubArrs);
for(int i = 0; i < numOfSubArrs; ++i)
hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
PointersManager manager(context, "helpers::unstack cuda");
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
unstackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers);
manager.synchronize();
}
else {
auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim}));
Nd4jLong* xTadShapeInfo = xTadPack.primaryShapeInfo();
for (uint i = 0; i < numOfSubArrs; ++i) {
void* xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]);
NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign,
nullptr, xTadShapeInfo, xBuff, xTadPack.specialShapeInfo(),
nullptr, outArrs[i]->getShapeInfo(), outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(),
nullptr, nullptr, nullptr, false/*allowParallelism*/);
}
}
// NDArray::registerSpecialUse(outArrs, {&input});
input.tickReadDevice();
for (const auto p : outArrs)
p->tickWriteDevice();
}
////////////////////////////////////////////////////////////////////////
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (context, input, outArrs, dim), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim), LIBND4J_TYPES);
///////////////////////////////////////////////////////////////////
// template <typename T>
// static __global__ void unstackCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
// const T* x = reinterpret_cast<const T*>(vx);
// __shared__ Nd4jLong xLen, totalThreads;
// __shared__ int xRank;
// if (threadIdx.x == 0) {
// xLen = shape::length(xShapeInfo);
// xRank = shape::rank(xShapeInfo);
// totalThreads = gridDim.x * blockDim.x;
// }
// __syncthreads();
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
// Nd4jLong coords[MAX_RANK];
// for (uint64_t i = tid; i < xLen; i += totalThreads) {
// shape::index2coords(i, xShapeInfo, coords);
// const auto xOffset = shape::getOffset(xShapeInfo, coords);
// T *z = reinterpret_cast<T*>(reinterpret_cast<void **>(pVz)[coords[axis]]);
// for (uint j = axis; j < xRank - 1; ++j) // shift coords starting from axis position
// coords[j] = coords[j + 1];
// const auto zOffset = shape::getOffset(zTadShapeInfo, coords);
// z[zOffset] = x[xOffset];
// }
// }
// ///////////////////////////////////////////////////////////////////
// template<typename T>
// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
// unstackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis);
// }
// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES);
// ///////////////////////////////////////////////////////////////////
// void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<const NDArray*>& outArrs, const int axis) {
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
// const int numOfSubArrs = outArrs.size();
// std::vector<void*> hOutBuffers(numOfSubArrs);
// for(int i = 0; i < numOfSubArrs; ++i)
// hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
// PointersManager manager(context, "helpers::unstack");
// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
// for(uint i = 0; i < numOfSubArrs; ++i)
// outArrs[i]->syncToDevice();
// input.syncToDevice();
// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->getSpecialShapeInfo(), axis), LIBND4J_TYPES);
// manager.synchronize();
// for(uint i = 0; i < numOfSubArrs; ++i)
// outArrs[i]->tickReadDevice();
// input.tickWriteDevice();
// }
// ///////////////////////////////////////////////////////////////////
// template <typename T>
// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
// T* z = reinterpret_cast<T*>(vz);
// __shared__ Nd4jLong zLen, totalThreads;
// __shared__ int zRank;
// if (threadIdx.x == 0) {
// zLen = shape::length(zShapeInfo);
// zRank = shape::rank(zShapeInfo);
// totalThreads = gridDim.x * blockDim.x;
// }
// __syncthreads();
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
// Nd4jLong coords[MAX_RANK];
// for (uint64_t i = tid; i < zLen; i += totalThreads) {
// shape::index2coords(i, zShapeInfo, coords);
// const auto zOffset = shape::getOffset(zShapeInfo, coords);
// const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[coords[axis]]);
// for (uint j = axis; j < zRank - 1; ++j) // shift coords starting from axis position
// coords[j] = coords[j + 1];
// const auto xOffset = shape::getOffset(xTadShapeInfo, coords);
// z[zOffset] = x[xOffset];
// }
// }
// ///////////////////////////////////////////////////////////////////
// template<typename T>
// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
// stackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, xTadShapeInfo, vz, zShapeInfo, axis);
// }
// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
// ///////////////////////////////////////////////////////////////////
// void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
// const int numOfSubArrs = inArrs.size();
// std::vector<void*> hInBuffers(numOfSubArrs);
// for(int i = 0; i < numOfSubArrs; ++i)
// hInBuffers[i] = inArrs[i]->getSpecialBuffer();
// PointersManager manager(context, "helpers::stack");
// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
// for(uint i = 0; i < numOfSubArrs; ++i)
// inArrs[i]->syncToDevice();
// output.syncToDevice();
// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, inArrs[0]->getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), axis), LIBND4J_TYPES);
// manager.synchronize();
// for(uint i = 0; i < numOfSubArrs; ++i)
// inArrs[i]->tickReadDevice();
// output.tickWriteDevice();
// }
}
}
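
For the scalar case the new kernels receive a replicated array of per-array buffer pointers and let each thread move exactly one element. A host-side analogue of that pointer-array gather (stack) and scatter (unstack), with plain loops standing in for the CUDA grid and no real device API involved (names are illustrative):

    #include <iostream>
    #include <vector>

    int main() {
        // three "scalar arrays", each owning a single element
        float s0 = 1.f, s1 = 2.f, s2 = 3.f;
        std::vector<const float*> inPtrs = {&s0, &s1, &s2};   // analogue of the replicated pVx array

        // stackScalars: element i of the output comes from inPtrs[i][0]
        std::vector<float> z(inPtrs.size());
        for (size_t i = 0; i < z.size(); ++i)
            z[i] = *inPtrs[i];

        // unstackScalars: the mirror scatter through an array of destination pointers
        float d0 = 0.f, d1 = 0.f, d2 = 0.f;
        std::vector<float*> outPtrs = {&d0, &d1, &d2};
        for (size_t i = 0; i < z.size(); ++i)
            *outPtrs[i] = z[i];

        std::cout << z[1] << ' ' << d2 << '\n';   // prints 2 and 3
    }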

View File

@ -124,19 +124,19 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
if(m < n)
throw std::runtime_error("svdQR: due to cuda api input constrains given shape of A array are not valid !");
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
if(std::vector<Nd4jLong>({minDim}) != S->getShapeAsVector())
throw std::runtime_error("svdQR: wrong shape of S array !");
if(calcUV) {
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
if(fullUV && std::vector<Nd4jLong>({m,m}) != U->getShapeAsVector())
throw std::runtime_error("svdQR: wrong shape of U array !");
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
else if(!fullUV && std::vector<Nd4jLong>({m,minDim}) != U->getShapeAsVector())
throw std::runtime_error("svdQR: wrong shape of U array !");
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT))
if(fullUV && std::vector<Nd4jLong>({n,n}) != VT->getShapeAsVector())
throw std::runtime_error("svdQR: wrong shape of VT array !");
else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT))
else if(!fullUV && std::vector<Nd4jLong>({minDim,n}) != VT->getShapeAsVector())
throw std::runtime_error("svdQR: wrong shape of VT array !");
}
@ -280,19 +280,19 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
int n = A->sizeAt(1);
const int minDim = m < n ? m : n;
if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
if(std::vector<Nd4jLong>({minDim}) != S->getShapeAsVector())
throw std::runtime_error("svdJcb: wrong shape of S array !");
if(calcUV) {
if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
if(fullUV && std::vector<Nd4jLong>({m,m}) != U->getShapeAsVector())
throw std::runtime_error("svdJcb: wrong shape of U array !");
else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
else if(!fullUV && std::vector<Nd4jLong>({m,minDim}) != U->getShapeAsVector())
throw std::runtime_error("svdJcb: wrong shape of U array !");
if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V))
if(fullUV && std::vector<Nd4jLong>({n,n}) != V->getShapeAsVector())
throw std::runtime_error("svdJcb: wrong shape of V array !");
else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V))
else if(!fullUV && std::vector<Nd4jLong>({n,minDim}) != V->getShapeAsVector())
throw std::runtime_error("svdJcb: wrong shape of V array !");
}
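The replacements above swap string-based shape checks for direct vector comparisons, so the human-readable shape text only needs to be produced when a check actually fails. A minimal standalone sketch of that pattern follows; the expectShape helper and its message format are illustrative, not library API:

    #include <cstdint>
    #include <stdexcept>
    #include <string>
    #include <vector>

    using Nd4jLong = int64_t;   // assumption: mirrors the library typedef

    // Illustrative helper: compare shapes cheaply, build the message only on failure.
    static void expectShape(const std::vector<Nd4jLong>& expected,
                            const std::vector<Nd4jLong>& actual,
                            const char* who) {
        if (expected != actual) {                                  // element-wise vector comparison
            std::string msg = std::string(who) + ": wrong shape, expected {";
            for (auto d : expected) msg += std::to_string(d) + ",";
            msg += "} !";
            throw std::runtime_error(msg);                         // string work happens only here
        }
    }

    int main() {
        expectShape({3}, {3}, "svdQR S");        // passes silently
        // expectShape({3,3}, {3,2}, "svdQR U"); // would throw with a formatted message
        return 0;
    }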

View File

@ -28,7 +28,8 @@ namespace sd {
namespace ops {
namespace helpers {
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim);
void stack (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& outArr, const int dim);
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim);
}

View File

@ -70,7 +70,11 @@ namespace helpers {
void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);
void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);
void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);
void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis);
void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);

View File

@ -43,11 +43,8 @@ namespace sd {
NDArray::prepareSpecialUse({z}, {x, y});
if (!x->isSameShape(y)) {
std::string sx = ShapeUtils::shapeAsString(x);
std::string sy = ShapeUtils::shapeAsString(y);
REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), sx.c_str(), sy.c_str());
}
if (!x->isSameShape(y))
REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
int opNum = block.opNum() < 0 ? this->_opNum : block.opNum();

View File

@ -43,11 +43,8 @@ namespace sd {
NDArray::prepareSpecialUse({z}, {x, y});
if (!x->isSameShape(y)) {
std::string sx = ShapeUtils::shapeAsString(x);
std::string sy = ShapeUtils::shapeAsString(y);
REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), sx.c_str(), sy.c_str());
}
if (!x->isSameShape(y))
REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());
int opNum = block.opNum() < 0 ? this->_opNum : block.opNum();

View File

@ -108,7 +108,7 @@ namespace sd {
// }
template <typename T>
void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
const int numOfInArrs = inArrs.size();
const auto sizeofT = output.sizeOfT();
@ -136,38 +136,44 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, ND
return;
}
const bool isZcontin = output.strideAt(axis) == 1 && output.ordering() == 'c';
bool areInputsContin = true;
bool allSameOrder = true;
if(isZcontin) {
for (uint i = 0; i < numOfInArrs; ++i) {
areInputsContin &= inArrs[i]->strideAt(axis) == 1;
allSameOrder &= inArrs[i]->ordering() == output.ordering();
if(!areInputsContin || !allSameOrder)
break;
}
}
const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and output array
const uint zDim = output.sizeAt(axis);
for (uint i = 0; i < output.lengthOf() / zDim; ++i) {
T* z = zBuff + zDim * i;
for (uint j = 0; j < inArrs.size(); ++j) {
const auto xDim = inArrs[j]->sizeAt(axis);
const T* x = inArrs[j]->bufferAsT<T>() + xDim * i;
memcpy(z, x, xDim * sizeofT);
z += xDim;
}
}
return;
}
// const bool isZcontin = output.strideAt(axis) == 1;
// bool areInputsContin = true;
// bool allSameOrder = true;
// std::vector<Nd4jLong> strideOfContigStride(numOfInArrs);
// if(isZcontin) {
// for (uint i = 0; i < numOfInArrs; ++i) {
// areInputsContin &= inArrs[i]->strideAt(axis) == 1;
// allSameOrder &= inArrs[i]->ordering() == output.ordering();
// if(!areInputsContin || !allSameOrder)
// break;
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
// }
// }
// const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and output array
// const auto zStep = shape::strideOverContigAxis(axis, output.getShapeInfo());
// for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
// T* z = zBuff + zStep * i;
// for (uint j = 0; j < inArrs.size(); ++j) {
// const auto xDim = inArrs[j]->sizeAt(axis);
// const T* x = inArrs[j]->bufferAsT<T>() + strideOfContigStride[j] * i;
// memcpy(z, x, xDim * sizeofT);
// z += xDim;
// }
// }
// return;
// }
// general case
auto func = PRAGMA_THREADS_FOR {
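For reference, a standalone sketch of the contiguous-axis copy that both variants of the luckCase2 fast path above are built around, written against plain row-major buffers with the illustrative shapes from the comment ({2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}); it is a toy model, not the library code:

    #include <cstring>
    #include <vector>

    int main() {
        const int d0 = 2, d2 = 3;
        const std::vector<int> axis1 = {1, 5, 10};              // axis-1 extents of the three inputs
        std::vector<std::vector<float>> in;
        for (int j = 0; j < 3; ++j)
            in.emplace_back(d0 * axis1[j] * d2, float(j));      // fill input j with the value j

        const int zAxis1 = 1 + 5 + 10;                          // concatenated extent along axis 1
        std::vector<float> out(d0 * zAxis1 * d2);

        for (int i = 0; i < d0; ++i) {                          // one slice per index of axis 0
            float* z = out.data() + i * zAxis1 * d2;
            for (int j = 0; j < 3; ++j) {
                const int chunk = axis1[j] * d2;                // contiguous chunk of input j
                std::memcpy(z, in[j].data() + i * chunk, chunk * sizeof(float));
                z += chunk;                                     // append the next input's chunk
            }
        }
        return 0;
    }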
@ -204,7 +210,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, ND
template <typename T>
void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong *resultShapeInfo) {
auto result = reinterpret_cast<T *>(vresult);
std::vector<NDArray*> inputs(numArrays);
std::vector<const NDArray*> inputs(numArrays);
NDArray output(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));
@ -217,6 +223,104 @@ void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint
delete inputs[i];
}
template <typename T>
void SpecialMethods<T>::splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
int numSplits = outArrs.size();
const auto sizeofT = input.sizeOfT();
T* xBuff = input.bufferAsT<T>();
bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
if (luckCase1) {
for (uint i = 0; i < numSplits; ++i) {
luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
if (!luckCase1)
break;
}
}
if (luckCase1) {
T* x = const_cast<T*>(xBuff);
for (uint i = 0; i < numSplits; ++i) {
const auto memAmountToCopy = outArrs[i]->lengthOf();
memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
x += memAmountToCopy;
}
return;
}
// const bool isXcontin = input.strideAt(axis) == 1;
// bool areOutsContin = true;
// bool allSameOrder = true;
// std::vector<Nd4jLong> strideOfContigStride(numSplits);
// if (isXcontin) {
// for (uint i = 0; i < numSplits; ++i) {
// areOutsContin &= outArrs[i]->strideAt(axis) == 1;
// allSameOrder &= outArrs[i]->ordering() == input.ordering();
// if (!areOutsContin || !allSameOrder)
// break;
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->getShapeInfo());
// }
// }
// const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
// if (luckCase2) {
// const auto xStep = shape::strideOverContigAxis(axis, input.getShapeInfo());
// for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) {
// T* x = xBuff + xStep * i;
// for (uint j = 0; j < numSplits; ++j) {
// const auto zDim = outArrs[j]->sizeAt(axis);
// T* z = outArrs[j]->bufferAsT<T>() + strideOfContigStride[j] * i;
// memcpy(z, x, zDim * sizeofT);
// x += zDim;
// }
// }
// return;
// }
uint zDim = outArrs[0]->sizeAt(axis);
// general case
auto func = PRAGMA_THREADS_FOR{
Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) {
shape::index2coords(i, input.getShapeInfo(), coords);
const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
uint outArrIdx = 0;
while (coords[axis] >= zDim) {
coords[axis] -= zDim;
++outArrIdx;
}
T* z = outArrs[outArrIdx]->bufferAsT<T>();
const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
z[zOffset] = xBuff[xOffset];
}
};
samediff::Threads::parallel_for(func, 0, input.lengthOf());
}
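The general case of splitCpuGeneric above walks every input element by its coordinates and picks the destination array by repeatedly subtracting the outputs' extent along the split axis. A toy sketch of that mapping on plain C arrays, assuming equal-sized outputs as the op requires:

    #include <cstdio>

    int main() {
        const int rows = 4, cols = 6, numSplits = 3;
        const int zDim = cols / numSplits;                 // extent of each output along axis 1
        float x[rows][cols], z[numSplits][rows][zDim];

        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c)
                x[r][c] = float(r * cols + c);

        for (int r = 0; r < rows; ++r)
            for (int c = 0; c < cols; ++c) {
                int axisCoord = c, outIdx = 0;
                while (axisCoord >= zDim) {                // same remapping as the while-loop above
                    axisCoord -= zDim;
                    ++outIdx;
                }
                z[outIdx][r][axisCoord] = x[r][c];
            }

        std::printf("%f\n", z[2][3][1]);                   // x[3][5] lands in output 2 -> 23.0
        return 0;
    }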
/**
* This kernel accumulates X arrays, and stores result into Z
*

View File

@ -50,8 +50,9 @@ namespace sd {
template <typename T>
class ND4J_EXPORT SpecialMethods {
public:
static void concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
static void concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis);
static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo);
static void splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis);
static void accumulateGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length);
static void averageGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length, bool propagate);

View File

@ -2853,289 +2853,6 @@ TEST_F(DeclarableOpsTests1, LRN1) {
lrn.getOpName();
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_1) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_2) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24};
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_3) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_4) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_5) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99};
Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_6) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24};
Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99};
Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99};
Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_7) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_8) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_9) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Stack_10) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0);
//expected.printShapeInfo("exp");
//output->printShapeInfo("out");
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
TEST_F(DeclarableOpsTests1, Stack_11) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
auto exp = NDArrayFactory::create<int>('c', {4});
exp.linspace(1);
@ -3330,74 +3047,6 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_1) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_2) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_3) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printShapeInfo();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, Reverse_1 ) {

View File

@ -250,68 +250,6 @@ TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
ASSERT_EQ(Status::OK(), result);
}
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
auto x = NDArrayFactory::create<float>('c', {0});
auto e = NDArrayFactory::create<float>('c', {1, 0});
sd::ops::stack op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
sd::ops::reduce_min sumOp;
auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
delete res2;
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {0});
sd::ops::stack op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {2, 0});
sd::ops::stack op;
auto result = op.evaluate({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
auto x = NDArrayFactory::create<float>('c', {0});
auto e = NDArrayFactory::create<float>('c', {2, 0});
sd::ops::stack op;
auto result = op.evaluate({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
auto e = NDArrayFactory::create<float>('c', {1, 0});
@ -1655,30 +1593,489 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_5D_4) {
ASSERT_EQ(e, z);
}
// @Test
// public void testMmulRank4_simple(){
// INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
// INDArray arr2 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
// DynamicCustomOp op = DynamicCustomOp.builder("matmul")
// .addInputs(arr1, arr2)
// .addIntegerArguments(0, 1) //Transpose arr2 only
// .build();
// List<LongShapeDescriptor> shapes = op.calculateOutputShape();
// assertEquals(1, shapes.size());
// long[] shape = new long[]{32,12,128,128};
// assertArrayEquals(shape, shapes.get(0).getShape());
// INDArray out = Nd4j.create(DataType.FLOAT, shape);
// op.setOutputArgument(0, out);
// Nd4j.exec(op);
// // System.out.println(out);
// INDArray exp = Nd4j.valueArrayOf(shape, 64.0, DataType.FLOAT); //Each entry in output is sum of 64 (1.0 x 1.0) multiplications
// assertEquals(exp, out);
// }
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_1) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
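The raw shapeInfo literals used by these Stack tests follow, as far as I can tell, the usual libnd4j layout {rank, shape..., strides..., extra, ews, order}; an annotated reading of the first test's arrays (illustrative only):

    // shape1   = {2,  3, 4,  4, 1,  0, 1, 99}       rank 2, shape {3,4}, strides {4,1}, ews 1, order 99 = 'c'
    // expShape = {3,  2, 3, 4,  12, 4, 1,  0, 1, 99} rank 3, shape {2,3,4}, strides {12,4,1}, 'c' order
    // The 0 right before the ews slot is the "extra"/flags field, which
    // ArrayOptions::setDataType(...) overwrites with FLOAT32 before the NDArrays are built.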
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_2) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24};
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_3) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_4) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_5) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99};
Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99};
Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_6) {
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24};
Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99};
Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99};
Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray input2(buff2, shape2);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input2}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_7) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_8) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_9) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_10) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
auto output = results->at(0);
//expected.printShapeInfo("exp");
//output->printShapeInfo("out");
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
TEST_F(DeclarableOpsTests14, Stack_11) {
float buff1[] = {1};
float expBuff[] = {1, 1, 1};
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
NDArray input1(buff1, shape1);
NDArray expected(expBuff, expShape);
sd::ops::stack op;
auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShapeStrict(*output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_12) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_13) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_14) {
float inBuff[] = {1.0f, 2.0f, 3.0f};
float expBuff[] = {1.0f, 2.0f, 3.0f};
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
sd::ops::stack op;
auto result = op.evaluate({&input}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printShapeInfo();
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_15) {
auto t = NDArrayFactory::create<double>('c', {2, 3, 5});
auto u = NDArrayFactory::create<double>('c', {2, 3, 5});
auto v = NDArrayFactory::create<double>('c', {2, 3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 2, 3, 5});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v}, {}, {-4});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_16) {
auto t = NDArrayFactory::create<float>(1.0f);
auto u = NDArrayFactory::create<float>(2.0f);
auto v = NDArrayFactory::create<float>(3.0f);
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_17) {
auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});
auto v = NDArrayFactory::create<float>('c', {1, 1}, {3.0f});
auto w = NDArrayFactory::create<float>('c', {1, 1}, {4.0f});
auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_18) {
auto x = NDArrayFactory::create<float>('c', {0});
auto e = NDArrayFactory::create<float>('c', {1, 0});
sd::ops::stack op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
sd::ops::reduce_min sumOp;
auto res2 = sumOp.evaluate({&e}, {1.}, {1});
ASSERT_EQ(res2->status(), Status::OK());
auto out = res2->at(0);
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
delete res2;
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_19) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {0});
sd::ops::stack op;
auto result = op.evaluate({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_20) {
auto x = NDArrayFactory::empty<float>();
auto e = NDArrayFactory::create<float>('c', {2, 0});
sd::ops::stack op;
auto result = op.evaluate({&x, &x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
TEST_F(DeclarableOpsTests14, Stack_21) {
NDArray x1('c', {3,2}, sd::DataType::FLOAT32);
NDArray x2('c', {3,2}, sd::DataType::FLOAT32);
x1.linspace(0);
x2.linspace(6);
sd::ops::stack opStack;
auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, resultStack->status());
sd::ops::concat opConcat;
auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, resultConcat->status());
auto outStack = resultStack->at(0);
auto outConcat = resultConcat->at(0);
outConcat->reshapei({2,3,2});
ASSERT_TRUE(outStack->isSameShape(outConcat));
ASSERT_TRUE(outStack->equalsTo(outConcat));
delete resultStack;
delete resultConcat;
}
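Stack_21 relies on the fact that stacking along a new leading axis and concatenating along axis 0 produce the same row-major buffer, differing only in shape; a tiny standalone check of that equivalence on toy data, not the op itself:

    #include <cassert>
    #include <vector>

    int main() {
        const std::vector<float> x1 = {0, 1, 2, 3, 4, 5};     // {3,2}, linspace(0)
        const std::vector<float> x2 = {6, 7, 8, 9, 10, 11};   // {3,2}, linspace(6)

        std::vector<float> stacked;                           // buffer of stack(x1, x2) -> {2,3,2}
        stacked.insert(stacked.end(), x1.begin(), x1.end());
        stacked.insert(stacked.end(), x2.begin(), x2.end());

        std::vector<float> concatenated;                      // buffer of concat(x1, x2, axis=0) -> {6,2}
        concatenated.insert(concatenated.end(), x1.begin(), x1.end());
        concatenated.insert(concatenated.end(), x2.begin(), x2.end());

        assert(stacked == concatenated);                      // reshapei({2,3,2}) only changes the shape
        return 0;
    }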

View File

@ -927,24 +927,6 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Stack_4) {
auto t = NDArrayFactory::create<double>('c', {2, 3, 5});
auto u = NDArrayFactory::create<double>('c', {2, 3, 5});
auto v = NDArrayFactory::create<double>('c', {2, 3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 2, 3, 5});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v}, {}, {-4});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
delete result;
}
TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});
@ -995,22 +977,6 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
delete result;
}
TEST_F(DeclarableOpsTests4, Test_1D_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
sd::ops::unstack op;
auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size());
for (int e = 0; e < 3; e++)
ASSERT_EQ(1, result->at(e)->rankOf());
delete result;
}
TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
auto exp = NDArrayFactory::create<double>('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

View File

@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests9, concat_test18) {
for (int e = 0; e < 2000; e++) {
auto exp = NDArrayFactory::create<int>('c', {300});
exp.assign(e);
auto row = z.tensorAlongDimension(e, {1});
auto row = z(e, {0});
ASSERT_EQ(exp, row);
}
}
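The replacement used here, and in the later test files of this commit, swaps the removed tensorAlongDimension for the sub-array operator(); as I read the NDArray API, the former took the dimensions the sub-array keeps, while the latter takes the dimensions to exclude, so both lines below select the same row of a rank-2 array:

    // auto row = z.tensorAlongDimension(e, {1});   // old: sub-array spans dimension 1
    // auto row = z(e, {0});                        // new: exclude (fix) dimension 0 at index e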
@ -778,6 +778,33 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test26) {
NDArray x0('f', {1, 2, 3}, sd::DataType::INT32);
NDArray x1('f', {1, 2, 3}, sd::DataType::INT32);
NDArray x2('f', {1, 2, 3}, sd::DataType::INT32);
NDArray exp('f', {3, 2, 3}, {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, sd::DataType::INT32);
x0.linspace(0);
x1.linspace(6);
x2.linspace(12);
sd::ops::concat op;
auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto output = result->at(0);
output->printLinearBuffer();
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
delete result;
}
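A worked check of concat_test26's expected buffer, assuming linspace fills logical row-major positions: each 'f'-order {1,2,3} input then holds 6*k + 3*r + c at [0][r][c], the concatenated {3,2,3} result holds out[k][r][c] = 6*k + 3*r + c, and its column-major buffer iterates k fastest, which reproduces {0,6,12,3,9,15, 1,7,13,4,10,16, 2,8,14,5,11,17}:

    #include <cstdio>

    int main() {
        for (int c = 0; c < 3; ++c)              // 'f' order: last dimension slowest
            for (int r = 0; r < 2; ++r)
                for (int k = 0; k < 3; ++k)      // first dimension fastest
                    std::printf("%d ", 6 * k + 3 * r + c);
        std::printf("\n");                       // 0 6 12 3 9 15 1 7 13 4 10 16 2 8 14 5 11 17
        return 0;
    }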
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

View File

@ -206,24 +206,17 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
}
TEST_F(EmptyTests, test_empty_scatter_2) {
auto x = NDArrayFactory::create<float>('c', {5});
auto z = NDArrayFactory::create<float>('c', {5});
NDArray x ('c', {5}, sd::DataType::FLOAT32);
NDArray z ('c', {5}, sd::DataType::FLOAT32);
auto indices = NDArrayFactory::create<int>('c', {0});
auto updates = NDArrayFactory::create<float>('c', {0});
x.linspace(1.0f);
Context ctx(1);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo());
ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
bool args[] = {true};
ctx.setBArguments(args, 1);
sd::ops::scatter_upd op;
auto result = op.execute(&ctx);
ASSERT_EQ(Status::OK(), result);
auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(x, z);
}

View File

@ -1242,7 +1242,7 @@ TEST_F(JavaInteropTests, test_ismax_view) {
v.assign(1.0);
auto e = v.like();
auto t = e.tensorAlongDimension(0, {0, 1});
auto t = e(0, {2});
t.assign(1.0);
auto z = v.ulike();

View File

@ -674,10 +674,10 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
auto e = NDArrayFactory::create<bool>('c', {3, 4});
e.assign(false);
auto row = y.tensorAlongDimension(1, {1});
auto row = y(1, {0});
row.assign(2.0f);
auto erow = e.tensorAlongDimension(1, {1});
auto erow = e(1, {0});
erow.assign(true);
auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1);

View File

@ -197,7 +197,7 @@ TEST_F(NDArrayTest, EqualityTest1) {
TEST_F(NDArrayTest, TestTad1) {
auto array = NDArrayFactory::create_<float>('c', {3, 3});
auto row2 = array->tensorAlongDimension(1, {1});
auto row2 = (*array)(1, {0});
ASSERT_TRUE(row2.isView());
ASSERT_EQ(3, row2.lengthOf());
@ -221,7 +221,7 @@ TEST_F(NDArrayTest, TestTad2) {
TEST_F(NDArrayTest, TestTad3) {
auto array = NDArrayFactory::create_<float>('c', {4, 3});
auto row2 = array->tensorAlongDimension(1, {1});
auto row2 = (*array)(1, {0});
ASSERT_TRUE(row2.isView());
ASSERT_EQ(3, row2.lengthOf());
@ -1529,7 +1529,7 @@ TEST_F(NDArrayTest, TestStdDev1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest, TestStdDev2) {
auto array = NDArrayFactory::create<double>('c', {5, 6});
auto tad = array.tensorAlongDimension(0, {0});
auto tad = array(0, {1});
ASSERT_EQ(5, tad.lengthOf());

View File

@ -229,7 +229,7 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({1,2, 0,0}, true);
auto row0 = syn0({1,2, 0,0}, true);
ASSERT_EQ(exp0, row0);
ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));
@ -418,10 +418,6 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
ASSERT_EQ(Status::OK(), result->status());
auto row0 = syn0({0,0, 0,0}, true);
auto row1 = syn0({5,0, 0,0}, true);
auto row2 = syn0({2,0, 0,0}, true);
delete result;
}

View File

@ -328,6 +328,25 @@ TEST_F(ParityOpsTests, TestUnstack12) {
delete result;
}
TEST_F(ParityOpsTests, TestUnstack13) {
auto x = NDArrayFactory::create<double>('c', {2, 3});
sd::ops::unstack op;
auto result = op.evaluate({&x}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
ASSERT_EQ(3, result->size());
for (int e = 0; e < 3; e++)
ASSERT_EQ(1, result->at(e)->rankOf());
delete result;
}
TEST_F(ParityOpsTests, ExpandDimsTest1) {
auto input = NDArrayFactory::create<float>('c', {5, 5});
input.linspace(1);

View File

@ -216,47 +216,6 @@ TEST_F(ScalarTests, Test_Permute_1) {
delete result;
}
TEST_F(ScalarTests, Test_Stack_1) {
auto t = NDArrayFactory::create<float>(1.0f);
auto u = NDArrayFactory::create<float>(2.0f);
auto v = NDArrayFactory::create<float>(3.0f);
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ScalarTests, Test_Stack_2) {
auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});
auto v = NDArrayFactory::create<float>('c', {1, 1}, {3.0f});
auto w = NDArrayFactory::create<float>('c', {1, 1}, {4.0f});
auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});
sd::ops::stack op;
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(ScalarTests, Test_Concat_Scalar_1) {
auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});
@ -268,7 +227,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));

View File

@ -196,7 +196,7 @@ public:
int dimensionLength = 2;
int dimension[2] = {2,3};
Nd4jLong tadAssertionC[10] = {3,4,4,1,4,1,16,16384,1,99};
Nd4jLong tadCAssertionF[10] = {3,4,4,1,1,4,16,16384,1,102};
Nd4jLong tadCAssertionF[10] = {3,4,4,1,1,4,1,16384,1,102};
};
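Reading the changed assertion with the same {rank, shape..., strides..., extra, ews, order} shapeInfo layout (my interpretation, not from the diff itself):

    // tadCAssertionF = { 3,   4, 4, 1,   1, 4, 1,   16384, 1, 102 }
    //                    rank shape      strides    extra  ews order ('f' = 102)
    // The size-1 trailing dimension previously carried stride 16; it now gets stride 1,
    // matching the commit's change to how shape.h assigns strides to unity dimensions
    // (any stride is formally valid there, since that index is always 0).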

View File

@ -130,7 +130,7 @@ TEST_F(TadTests, TadEdgeCase_1) {
auto exp = NDArrayFactory::create<float>('c', {5, 4});
array.linspace(1);
auto tad = array.tensorAlongDimension(0, {0, 1});
auto tad = array(0, {2});
ASSERT_TRUE(exp.isSameShape(tad));
}
@ -140,7 +140,7 @@ TEST_F(TadTests, TestEdgeCase_2) {
auto array = NDArrayFactory::create<float>('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6});
for (int e = 0 ; e < array.lengthOf(); e++) {
auto tad = array.tensorAlongDimension(e, {2});
auto tad = array(e, {0,1});
ASSERT_NEAR(tad.e<float>(0), array.e<float>(e), 1e-5);
}
}
@ -148,7 +148,7 @@ TEST_F(TadTests, TestEdgeCase_2) {
TEST_F(TadTests, TadEdgeCase_2) {
auto array = NDArrayFactory::create<float>('c', {2, 3, 4});
auto tad = array.tensorAlongDimension(0, {1});
auto tad = array(0, {0,2});
ASSERT_EQ(3, tad.lengthOf());
}