profiling of stack and unstack ops (#261)
* profiling of stack and unstack ops
* fix bug in cpu concat op
* correction of cuda stack and unstack
* change shape.h method which operates with unity-dimension strides
* rearrange stack tests
* correct evaluation of smallest stride for moving through contiguous axis
* forgot to update signature of function strideOverContigAxis in cuda concat and split ops
* remove ShapeUtils::shapeAsString method applied before input arrays validations
* further removing of ShapeUtils::shapeAsString
* take sub-array shapeInfo/offset calculation out of NDArray class; add possibility of contiguous memory copy in execTransformAny op if opNum == assign
* correct test_empty_scatter_2 in EmptyTests.cpp
* profiling of slice op
* get rid of contiguous memcpy for some cases in concat and split ops
* forgot to declare void nd4j::SpecialMethods<T>::splitCpuGeneric
* correct typo in calculation of threads in cuda split op
* forgot to correct another set of threads variables in split cuda ops
* further conflicts resolving

Signed-off-by: Yurii <iuriish@yahoo.com>
Co-authored-by: raver119 <raver119@gmail.com>
parent 0f581e74e3
commit 78934c17ad
@@ -903,14 +903,6 @@ namespace sd {
*/
void transposei();

/**
* return array pointing on certain range of this array
* index - the number of array to be returned among set of possible arrays
* dimensions - array of dimensions to point on
*/
NDArray tensorAlongDimension(Nd4jLong index, const std::initializer_list<int>& dimensions) const;
NDArray tensorAlongDimension(Nd4jLong index, const std::vector<int>& dimensions) const;

/**
* returns the number of arrays pointing on specified dimension(s)
* dimensions - array of dimensions to point on
@@ -1197,14 +1197,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
}

// memcpy is allowed only for same order c && same ews (being equal to 1)
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism);
NDArray::registerSpecialUse({this}, {&other});
}
NDArray::prepareSpecialUse({this}, {&other});
NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism);
NDArray::registerSpecialUse({this}, {&other});
}
}
@@ -4810,30 +4805,6 @@ ResultSet NDArray::allTensorsAlongDimension(const std::vector<int> &dimensions)
return result;
}

//////////////////////////////////////////////////////////////////////////
NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::vector<int>& dimensions) const {
std::vector<int> copy(dimensions);
shape::checkDimensions(rankOf(), copy);

Nd4jLong tadLength = shape::tadLength(this->getShapeInfo(), copy.data(), copy.size());
Nd4jLong numTads = this->lengthOf() / tadLength;

if (index >= numTads)
throw std::runtime_error("Can't get index higher than total number of TADs");

auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);

NDArray array(_buffer, ShapeDescriptor(packX.primaryShapeInfo()), getContext(), packX.primaryOffsets()[index] + getBufferOffset());
array._isView = true;

return array;
}

//////////////////////////////////////////////////////////////////////////
NDArray NDArray::tensorAlongDimension(Nd4jLong index, const std::initializer_list<int>& dimensions) const {
return tensorAlongDimension(index, std::vector<int>(dimensions));
}

////////////////////////////////////////////////////////////////////////
// operator returns sub-array with buffer pointing at this->_buffer + certain offset
NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUnitiesInShape, const bool isStrided) const {
@@ -4841,63 +4812,73 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(isEmpty())
throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !");

const int rank = rankOf();
Nd4jLong *newShapeInfo = ShapeBuilders::copyShapeInfo(getShapeInfo(), true, getContext()->getWorkspace());
// Nd4jLong *outShapeInfo = nullptr;
// ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong);

auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo);
int numOfUntiesInSubArrShape = 0;

Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride;

for (int d = rank - 1; d >= 0; --d) {

if (idx[n * d] != idx[n * d + 1]) {

first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1;
last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1;
stride = isStrided ? idx[n * d + 2] : 1;

shapeOf[d] = (last - first + stride - 1) / stride; // ceil (last - first) / stride;
offset += first * stridesOf[d];

if(shapeOf[d] != 1)
stridesOf[d] *= stride;
}
}

Nd4jLong *newShapeInfo2 = newShapeInfo;
Nd4jLong* subArrShapeInfo = nullptr;

if(!keepUnitiesInShape) {

std::vector<int> dimsWithUnities;
int n(isStrided ? 3 : 2), first, last;

for (int d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
// calculate the number of unities in shape
for (uint d = 0; d < rankOf(); ++d) {

if(!dimsWithUnities.empty())
newShapeInfo2 = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
if (idx[n * d] != idx[n * d + 1]) {

first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1;
last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1;
if(last - first == 1)
++numOfUntiesInSubArrShape;
}
}
}

// check if there is possibility to set ews = 1
shape::checkStridesEwsAndOrder(newShapeInfo2);
ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong);

NDArray result(_buffer, ShapeDescriptor(newShapeInfo2), getContext(), offset + getBufferOffset());
Nd4jLong offset;

shape::calcSubArrShapeInfoAndOffset(idx.data(), getShapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape);

NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + getBufferOffset());
result._isView = true;

RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != newShapeInfo2)
RELEASE(newShapeInfo2, getContext()->getWorkspace());
RELEASE(subArrShapeInfo, getContext()->getWorkspace());

return result;
}

////////////////////////////////////////////////////////////////////////
NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector<int>& dimsToExclude, bool keepUnitiesInShape) const {

std::vector<Nd4jLong> idxRanges(2 * rankOf());

ShapeUtils::evalIdxRangesForSubArr(subArrIdx, _shapeInfo, dimsToExclude, idxRanges.data());
const auto rank = rankOf();
const auto subArrRank = static_cast<int>(dimsToExclude.size());

if(subArrRank > rank)
throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector<int>& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !");

memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong));

// subArrRank == 0 means whole array, idxRanges should contain zeros only

if(subArrRank != 0) {

std::vector<Nd4jLong> shapeOfSubArr(subArrRank), indexes(subArrRank);
for(int i = 0; i < subArrRank; ++i)
shapeOfSubArr[i] = sizeAt(dimsToExclude[i]);

shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data());

for(int i = 0; i < subArrRank; ++i) {
int currIdx = 2 * dimsToExclude[i];
idxRanges[currIdx] = indexes[i];
idxRanges[currIdx + 1] = indexes[i] + 1;
}
}

return (*this)(idxRanges, keepUnitiesInShape);
}
@@ -4916,7 +4897,7 @@ void NDArray::getSubArrShapeAndOffsets(const std::vector<int>& dimsToExclude, Nd
ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong);
ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong);

shape::calcSubArrShapeAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape);
shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape);
}

//////////////////////////////////////////////////////////////////////////
@@ -138,16 +138,6 @@ namespace sd {
*/
static Nd4jLong getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude);

/**
* evaluate indexes ranges that define sub-array of array having shape=shapeInfo
* subArrIdx - index of current sub-array
* shapeInfo - shapeInfo of array for which to evaluate sub-arrays
* dimsToExclude - MUST BE SORTED, dimensions to evaluate sub-arrays along, i.e. when shape is [2,3,4,5] and dimsToExclude={0,2}, then there will be 8 sub-arrays with shape [3,5],
* if dimsToExclude is empty then idxRanges containing all zeros (means whole array) will be returned.
* idxRanges - where to put result, the length of idxRanges must be equal to 2*shapeInfo[0]
*/
static void evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude, Nd4jLong* idxRanges);

/**
* return shape without unities, for example if shape is [1,2,1,3] then [2,3] will be returned
* if unities are not present in given shapeInfo then exactly identical shape will be returned, for example [2,3] -> [2,3]
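To make the idxRanges convention above concrete, here is a minimal, self-contained sketch that re-derives the documented example with plain std::vector values instead of the library's Nd4jLong* shapeInfo buffers; the helper logic and names below are illustrative only, not part of the ShapeUtils API.

// illustrative only: shape [2,3,4,5], dimsToExclude = {0,2} -> 8 sub-arrays of shape [3,5]
#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> shape = {2, 3, 4, 5};
    std::vector<int> dimsToExclude = {0, 2};
    long long subArrIdx = 5;                               // pick the 6th of the 2*4 = 8 sub-arrays

    std::vector<long long> idxSpace;                       // sizes of the excluded dimensions: {2, 4}
    for (int d : dimsToExclude) idxSpace.push_back(shape[d]);

    std::vector<long long> coords(idxSpace.size());        // unravel subArrIdx in c-order: 5 -> {1, 1}
    for (int i = (int)idxSpace.size() - 1; i >= 0; --i) {
        coords[i] = subArrIdx % idxSpace[i];
        subArrIdx /= idxSpace[i];
    }

    std::vector<long long> idxRanges(2 * shape.size(), 0); // zeros mean "whole dimension"
    for (size_t i = 0; i < dimsToExclude.size(); ++i) {
        idxRanges[2 * dimsToExclude[i]]     = coords[i];
        idxRanges[2 * dimsToExclude[i] + 1] = coords[i] + 1;
    }

    for (auto v : idxRanges) printf("%lld ", v);           // prints: 1 2 0 0 1 2 0 0
    printf("\n");
}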
@@ -202,6 +192,11 @@ namespace sd {

static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims);
*/

/*
* comparing of shapes, not strides
*/
static bool areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4jLong>& shapeOnly);
};
@@ -73,7 +73,7 @@ namespace sd {
auto oPtr = new Nd4jLong[numOfSubArrs];

if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());

ConstantDataBuffer shapesBuffer(sPtr, nullptr, shape::shapeInfoLength(subArrRank)*sizeof(Nd4jLong), DataType::INT64);
@@ -77,7 +77,7 @@ namespace sd {
auto oPtr = new Nd4jLong[numOfSubArrs];

if (numOfSubArrs > 0)
shape::calcSubArrShapeAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());
shape::calcSubArrsShapeInfoAndOffsets(shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), sPtr, oPtr, descriptor.areUnitiesinShape());

Nd4jPointer soPtr;
auto res = cudaMalloc(reinterpret_cast<void**>(&soPtr), numOfSubArrs * sizeof(Nd4jLong));
@@ -940,16 +940,16 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
std::vector<void*> aSubArrs(bS), bSubArrs(bS), cSubArrs(bS);

if(aRank > 2)
shape::calcSubArrShapeAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pA->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
aSubArrs[i] = aRank == 2 ? pA->getSpecialBuffer() : pA->getSpecialBuffer() + subArrOffsets[i] * pA->sizeOfT();

if(bRank > 2)
shape::calcSubArrShapeAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pB->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
bSubArrs[i] = bRank == 2 ? pB->getSpecialBuffer() : pB->getSpecialBuffer() + subArrOffsets[i] * pB->sizeOfT();

shape::calcSubArrShapeAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
shape::calcSubArrsShapeInfoAndOffsets(pC->getShapeInfo(), bS, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo.data(), subArrOffsets.data());
for (int i = 0; i < bS; ++i)
cSubArrs[i] = pC->getSpecialBuffer() + subArrOffsets[i] * pC->sizeOfT();
@@ -974,35 +974,6 @@ Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vecto
return numOfSubArrs;
}

////////////////////////////////////////////////////////////////////////////////
void ShapeUtils::evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude, Nd4jLong* idxRanges) {

const auto rank = shape::rank(shapeInfo);
const auto subArrRank = static_cast<int>(dimsToExclude.size());

if(subArrRank > rank)
throw std::invalid_argument("ShapeUtils::evalIdxRangesForSubArr static method: dimsToExclude is empty or has size > rank of array !");

if(subArrRank == 0) { // means whole array
memset(idxRanges, 0, 2 * rank * sizeof(Nd4jLong));
return;
}

std::vector<Nd4jLong> shapeOfSubArr(subArrRank), indexes(subArrRank);
for(int i = 0; i < subArrRank; ++i)
shapeOfSubArr[i] = shapeInfo[dimsToExclude[i] + 1];

shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data());

memset(idxRanges, 0, 2 * rank * sizeof(Nd4jLong));

for(int i = 0; i < subArrRank; ++i) {
int currIdx = 2 * dimsToExclude[i];
idxRanges[currIdx] = indexes[i];
idxRanges[currIdx + 1] = indexes[i] + 1;
}
}

////////////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalDimsWithoutUnities(const Nd4jLong* shapeInfo) {
@@ -1080,6 +1051,19 @@ void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, co
}
}
}

bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4jLong>& shapeOnly) {

if(shape::rank(shapeInfo) != shapeOnly.size())
return false;

for(uint i = 0; i < shape::rank(shapeInfo); ++i)
if(shape::shapeOf(shapeInfo)[i] != shapeOnly[i])
return false;

return true;
}

////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
@@ -117,7 +117,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2);
ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3);

ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shape, const int dim);
ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim);

template <typename T>
ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length);
@@ -469,9 +469,6 @@ namespace shape {
ND4J_EXPORT _CUDA_HD int rank(const int *shapeInfo);
ND4J_EXPORT _CUDA_HD int rank(const unsigned int *shapeInfo);

// returns pointer on elementWiseStride
ND4J_EXPORT _CUDA_HD Nd4jLong* ews(Nd4jLong* shapeInfo);

/**
* returns pointer on elementWiseStride
*/
@@ -1029,7 +1026,23 @@ namespace shape {
* subArrOffsets - output argument, contains successive sub-arrays offsets from original this-buffer
* keepUnitiesInShape - if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
*/
ND4J_EXPORT _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);
ND4J_EXPORT _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape = false);

/**
* processes only one sub-array, evaluates shapeInfo of sub-array and its buffer offset from original array
* arguments:
* idx - input argument, intervals of indexes which define the sub-array to point on,
* when isStrided = false then idx has form {dim0Start,dim0End, dim1Start,dim1End, ....} and length (2 * maxRank)
* when isStrided = true then idx has form {dim0Start,dim0End,dim0Stride, dim1Start,dim1End,dim1Stride, ....} and length (3 * maxRank)
* when (dimStart == dimEnd) then whole range will be used for current dimension
* maxShapeInfo - input argument, shapeInfo of original array
* minShapeInfo - output argument, shapeInfo of sub-array to be deduced
* minOffset - output argument, offset of sub-array buffer offsets from original buffer
* keepUnitiesInShape - input argument, if false then eliminate unities from sub-array shapeInfo, for example {1,a,1,b} -> {a,b}
* isStrided - input argument, if true then idx has length (3 * this->rankOf()) and contains additional stride numbers which correspond to stride between dimStart and dimEnd,
* numOfUntiesInMinShape - input argument, number of occurrences in idx when (dimEnd - dimStart) = 1
*/
ND4J_EXPORT void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape = false, const bool isStrided = false, const int numOfUntiesInMinShape = 0);
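As a concrete reading of the idx convention just documented, the following self-contained sketch applies it to a 5x4 c-ordered array; plain vectors stand in for the real shapeInfo layout, so this illustrates the convention only and is not the library routine.

// illustrative only: idx = {1,3, 0,0} selects rows 1..2 and every column of a 5x4 array
#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> shape  = {5, 4};
    std::vector<long long> stride = {4, 1};                // c-order strides of a 5x4 array
    std::vector<long long> idx    = {1, 3, 0, 0};          // {dim0Start,dim0End, dim1Start,dim1End}

    long long offset = 0;
    std::vector<long long> subShape, subStride;

    for (size_t d = 0; d < shape.size(); ++d) {
        long long first = idx[2 * d], last = idx[2 * d + 1];
        if (first == last) {                               // dimStart == dimEnd -> keep whole dimension
            subShape.push_back(shape[d]);
            subStride.push_back(stride[d]);
        } else {
            offset += first * stride[d];                   // sub-array buffer offset
            subShape.push_back(last - first);
            subStride.push_back(stride[d]);
        }
    }

    printf("offset = %lld, shape = {%lld, %lld}\n", offset, subShape[0], subShape[1]);  // offset = 4, shape = {2, 4}
}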
/**
* for example inShapeInfo is {3, 2,1,4, 4,4,1, 16384,1,99}
@@ -1046,6 +1059,12 @@ namespace shape {
*/
INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, Nd4jLong* outShapeInfo);

/**
* get stride over contiguous axis (contiguous axis must have stride = 1)
* for example when inShapeInfo is {4, 2,5,4,3, 60,1,5,20, 16384,0,99} then output is 5 (that is smallest stride in inShapeInfo except those equal to 1)
*/
INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo);
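The documented example can be checked with a short, self-contained sketch that mirrors the rule on plain vectors (shape {2,5,4,3}, strides {60,1,5,20}, axis 1 being the contiguous one); the helper name below is ours, not the shape namespace function.

// illustrative only: smallest stride over all axes except the contiguous one and unity dims
#include <climits>
#include <cstdio>
#include <vector>

long long smallestNonContigStride(int axis, const std::vector<long long>& shape, const std::vector<long long>& stride) {
    long long result = LLONG_MAX;
    for (size_t i = 0; i < shape.size(); ++i) {
        if ((int)i == axis || shape[i] == 1)
            continue;                                      // skip the contiguous axis and size-1 dims
        if (stride[i] < result)
            result = stride[i];
    }
    return result == LLONG_MAX ? 1 : result;
}

int main() {
    printf("%lld\n", smallestNonContigStride(1, {2, 5, 4, 3}, {60, 1, 5, 20}));   // prints 5
}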
@@ -2961,13 +2980,13 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons

return shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo2) && shape::haveSameShapeAndStrides(shapeInfo1, shapeInfo3);
}
INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shape, const int dim) {
if (0 == rank(shape))
INLINEDEF _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim) {
if (0 == rank(shapeInfo))
return 1;
if (dim >= 0)
return shape[1+dim];
return shapeInfo[1+dim];
else
return shape[1+(rank(shape) + dim)];
return shapeInfo[1+(rank(shapeInfo) + dim)];
}

/**
@@ -4683,21 +4702,6 @@ INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char

if(contiguous) {

// for example we have shapeInfo = {3, 5,1,1, 4,4,1, ...} then we should change it to shapeInfo = {3, 5,1,1, 4,4,4, ...ews=4}
if(numOfNonUnities < rank) { // unities are present in shape

int indNonUnit = rank - 1;

while(shape::shapeOf(shapeInfo)[indNonUnit--] == 1)

for(int j = indNonUnit + 2; j < rank; ++j)
shape::stride(shapeInfo)[j] = stridesNoUnities[numOfNonUnities - 1];

for(int j = indNonUnit; j >= 0; --j)
if(shape::shapeOf(shapeInfo)[j] == 1)
shape::stride(shapeInfo)[j] = shape::shapeOf(shapeInfo)[j + 1] * shape::stride(shapeInfo)[j + 1];
}

*shape::ews(shapeInfo) = stridesNoUnities[numOfNonUnities - 1];
shapeInfo[rank * 2 + 3] = 99;
return;
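A compact sketch of the c-order case commented above: trailing unity dimensions are given the stride of the last non-unity axis so that the buffer keeps a single element-wise stride. This assumes plain vectors and trailing unities only; the real routine also covers leading and interior unities and the f-order branch below.

// illustrative only: {5,1,1} with strides {4,4,1} becomes strides {4,4,4}, ews = 4
#include <cstdio>
#include <vector>

int main() {
    std::vector<long long> shape  = {5, 1, 1};
    std::vector<long long> stride = {4, 4, 1};

    int lastNonUnit = (int)shape.size() - 1;
    while (lastNonUnit > 0 && shape[lastNonUnit] == 1)
        --lastNonUnit;                                     // index of last non-unity dimension

    for (size_t j = lastNonUnit + 1; j < shape.size(); ++j)
        stride[j] = stride[lastNonUnit];                   // propagate its stride into trailing unities

    printf("strides = {%lld,%lld,%lld}, ews = %lld\n", stride[0], stride[1], stride[2], stride[lastNonUnit]);
}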
@@ -4715,21 +4719,6 @@ INLINEDEF void _CUDA_HD checkStridesEwsAndOrder(Nd4jLong* shapeInfo, const char

if(contiguous) {

// for example we have shapeInfo = {3, 1,1,5, 1,4,4, ...} then we should change it to shapeInfo = {3, 1,1,5, 4,4,4, ...ews=4}
if(numOfNonUnities < rank) { // unities are present in shape

int indNonUnit = 0;

while(shape::shapeOf(shapeInfo)[indNonUnit++] == 1)

for(int j = 0; j < indNonUnit - 1; ++j)
shape::stride(shapeInfo)[j] = stridesNoUnities[0];

for(int j = indNonUnit; j < rank; ++j)
if(shape::shapeOf(shapeInfo)[j] == 1)
shape::stride(shapeInfo)[j] = shape::shapeOf(shapeInfo)[j - 1] * shape::stride(shapeInfo)[j - 1];
}

*shape::ews(shapeInfo) = stridesNoUnities[0];
shapeInfo[rank * 2 + 3] = 102;
return;
@@ -4740,7 +4729,7 @@
}

//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) {
INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShapeInfo, const Nd4jLong numOfSubArrs, const int dimsSize, const int* dimsToExclude, Nd4jLong* subArrShapeInfo, Nd4jLong* subArrOffsets, bool keepUnitiesInShape) {

const int rank = shape::rank(wholeShapeInfo);
@@ -4788,6 +4777,54 @@ INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo
delete []shape;
}

//////////////////////////////////////////////////////////////////////
INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* maxShapeInfo, Nd4jLong* minShapeInfo, Nd4jLong& minOffset, const bool keepUnitiesInShape, const bool isStrided, const int numOfUntiesInMinShape) {

const uint maxRank = shape::rank(maxShapeInfo);
minOffset = 0;
uint first, last, stride, n(isStrided ? 3 : 2);

minShapeInfo[0] = keepUnitiesInShape ? maxRank : maxRank - numOfUntiesInMinShape;

for (uint step = 0, j = 0, i = 0; i < maxRank; ++i, step += n) {

if (idx[step] == idx[step + 1]) { // means whole dimension
shape::shapeOf(minShapeInfo)[j] = shape::shapeOf(maxShapeInfo)[i];
shape::stride(minShapeInfo)[j++] = shape::stride(maxShapeInfo)[i];
}
else {

first = idx[step] >= 0 ? idx[step] : idx[step] + shape::sizeAt(maxShapeInfo, i) + 1;
last = idx[step + 1] >= 0 ? idx[step + 1] : idx[step + 1] + shape::sizeAt(maxShapeInfo, i) + 1;

if(last < first)
throw("shape::calcSubArrShapeInfoAndOffset: negative range in input indexes is found!");

if(isStrided) {
stride = idx[step + 2];
last /*resulting sub-array axis*/ = (last - first + stride - 1) / stride; // ceil (last - first) / stride;
}
else {
stride = 1;
last /*resulting sub-array axis*/ = last - first;
}

minOffset += first * shape::stride(maxShapeInfo)[i];

if(!keepUnitiesInShape && last == 1)
continue;

shape::shapeOf(minShapeInfo)[j] = last;
shape::stride(minShapeInfo)[j++] = last == 1 ? shape::stride(maxShapeInfo)[i] : shape::stride(maxShapeInfo)[i] * stride;
}
}

minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order
minShapeInfo[2 * shape::rank(minShapeInfo) + 1] = shape::type(maxShapeInfo); // type

shape::checkStridesEwsAndOrder(minShapeInfo);
}

//////////////////////////////////////////////////////////////////////
INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords) {
@@ -5083,6 +5120,27 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo,
}

//////////////////////////////////////////////////////////////////////
INLINEDEF _CUDA_HD Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {

Nd4jLong result = 9223372036854775807LL;

for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {

const auto currentStride = shape::stride(inShapeInfo)[i];

if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
continue;

if(result > currentStride)
result = currentStride;
}

return result == 9223372036854775807LL ? 1 : result;
}

}

#endif /* SHAPE_H_ */
@@ -1233,11 +1233,18 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
if (shape::isEmpty(hXShapeInfo))
return;

auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) {

samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
memcpy(hZ, hX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType));
}
else {
auto func = PRAGMA_THREADS_DO {

BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};

samediff::Threads::parallel_do(func, sd::math::nd4j_max<int>(1, sd::math::nd4j_min<int>(shape::length(hZShapeInfo) / 1024, sd::Environment::getInstance()->maxMasterThreads())));
}
}

////////////////////////////////////////////////////////////////////////
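The Assign branch above reduces the transform to one contiguous copy whenever both buffers are dense, c-ordered and of the same type; a minimal stand-alone illustration of that idea follows (plain arrays and boolean stand-ins, not the executioner API).

// illustrative only: element-wise assign degenerates to memcpy for dense same-type c-order buffers
#include <cstdio>
#include <cstring>

int main() {
    float x[6] = {1, 2, 3, 4, 5, 6};
    float z[6] = {0};

    const bool sameType   = true;   // xType == zType
    const bool bothDenseC = true;   // order 'c' and elementWiseStride == 1 for both arrays

    if (sameType && bothDenseC)
        std::memcpy(z, x, 6 * sizeof(float));              // single contiguous copy
    else
        for (int i = 0; i < 6; ++i) z[i] = x[i];           // generic element-wise fallback

    printf("%g %g\n", z[0], z[5]);                         // 1 6
}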
@@ -926,9 +926,14 @@ void NativeOpExecutioner::execTransformAny(sd::LaunchContext *lc,
if (shape::isEmpty(hXShapeInfo))
return;

dim3 launchDims(512, 512, 2048);
if (opNum == sd::transform::Assign && shape::order(hXShapeInfo) == shape::order(hZShapeInfo) && shape::order(hXShapeInfo) == 'c' && xType == zType && shape::elementWiseStride(hXShapeInfo) == 1 && shape::elementWiseStride(hZShapeInfo) == 1) {
cudaMemcpyAsync(dZ, dX, shape::length(hXShapeInfo) * sd::DataTypeUtils::sizeOfElement(xType), cudaMemcpyDeviceToDevice, *stream);
}
else {

BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
}

// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@@ -63,8 +63,8 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
int iC = input->sizeAt(indIOioC); // input channels
int oC = weights->sizeAt(indWoC); // output channels

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
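The pattern this commit switches to here and in the ops below compares shape values numerically and only builds printable strings when a check fails; a small self-contained sketch of that pattern follows (the shapeAsString helper below is a local stand-in, not the ShapeUtils function).

// illustrative only: numeric shape comparison first, strings only for the error message
#include <cstdio>
#include <string>
#include <vector>

static std::string shapeAsString(const std::vector<long long>& s) {
    std::string out = "[";
    for (size_t i = 0; i < s.size(); ++i)
        out += std::to_string(s[i]) + (i + 1 < s.size() ? ", " : "");
    return out + "]";
}

int main() {
    const long long kW = 2, iC = 3, oC = 8;
    std::vector<long long> expected = {kW, iC, oC};
    std::vector<long long> actual   = {2, 3, 8};           // pretend: weights shape

    if (actual != expected)
        printf("wrong weights shape, expected %s but got %s\n", shapeAsString(expected).c_str(), shapeAsString(actual).c_str());
    else
        printf("weights shape %s is valid\n", shapeAsString(actual).c_str());
}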
@@ -123,8 +123,8 @@ DECLARE_SHAPE_FN(conv1d) {
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -198,10 +198,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -267,10 +267,10 @@ DECLARE_SHAPE_FN(conv1d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH,trueoW, 1,kW, 1,sW, 0,pW, 1,dW, 1,iW, paddingMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoW, 0,indIOioC,indIiW});
std::vector<Nd4jLong> expectedWeightsShape = {kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV1D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV1D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -58,8 +58,8 @@ CUSTOM_OP_IMPL(conv2d, 2, 1, false, 0, 9) {
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -109,8 +109,8 @@ DECLARE_SHAPE_FN(conv2d) {
const int iC = inputShapeInfo[indIOioC+1]; // input channels
const int oC = weightsShapeInfo[indWoC+1]; // output channels

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -187,10 +187,10 @@ CUSTOM_OP_IMPL(conv2d_bp, 3, 2, false, 0, 9) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -242,10 +242,10 @@ DECLARE_SHAPE_FN(conv2d_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -300,10 +300,10 @@ CUSTOM_OP_IMPL(conv2d_input_bp, 3, 1, false, 0, 9) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());

ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);
@@ -360,10 +360,10 @@ DECLARE_SHAPE_FN(conv2d_input_bp) {
int trueoH, trueoW; // true output height, width
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV2D_INPUT_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());

Nd4jLong* gradIshapeInfo(nullptr);
ALLOCATE(gradIshapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
@@ -59,8 +59,8 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -139,8 +139,8 @@ DECLARE_SHAPE_FN(conv3dnew) {
int iC = inputShapeInfo[indIOioC+1]; // input channels
int oC = weightsShapeInfo[indWoC+1]; // output channels

std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -209,10 +209,10 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

REQUIRE_TRUE(paddingMode < 2, 0, "CUSTOM CONV3D_BP OP: causal padding mode (paddingMode = 2) is not allowed for this operation !");
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
@@ -313,10 +313,10 @@ DECLARE_SHAPE_FN(conv3dnew_bp) {
int trueoD, trueoH, trueoW; // true output depth/height/width
ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2}));
std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC});
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoD,trueoH,trueoW, 0,indIOioC,indIiD,indIiD+1,indIiD+2});
std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if(biasShapeInfo)
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
@@ -31,13 +31,13 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
auto weights = INPUT_VARIABLE(1); // [1, 1, iC, oC] always
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]

auto output = OUTPUT_VARIABLE(0); // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)

REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf());
REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf());

int kH = 1; // filter(kernel) height
int kW = 1; // filter(kernel) width
@@ -46,18 +46,18 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {
int pH = 0;                                                       // paddings height
int pW = 0;                                                       // paddings width
int dH = 1;                                                       // dilations height
int dW = 1;                                                       // dilations width
int isNCHW = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1; // INT_ARG(0): 0-NCHW, 1-NHWC

int bS, iC, iH, iW, oC, oH, oW;                                   // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;             // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

std::string expectedWeightsShape = ShapeUtils::shapeAsString({1, 1, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
if (bias)
    REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

ConvolutionUtils::conv2d(block, input, weights, bias, output, kH,kW, sH,sW, pH,pW, dH,dW, 1/*isSameMode*/, isNCHW);

return Status::OK();
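With kH = kW = 1, unit strides and dilations, and no padding, pointwise_conv2d is simply a 1x1 conv2d, so the spatial output size equals the input size and only the channel count changes. A hedged, test-style sketch of calling the op (the evaluate/ResultSet pattern mirrors what this changeset uses elsewhere; array names and values are illustrative, and the NDArrayFactory include path depends on the build layout):

#include <ops/declarable/CustomOperations.h>
// plus the NDArrayFactory header, whose path is assumed here

void runPointwiseConv2dSketch() {
    auto input   = NDArrayFactory::create<float>('c', {2, 4, 4, 3});   // [bS, iH, iW, iC], NHWC
    auto weights = NDArrayFactory::create<float>('c', {1, 1, 3, 8});   // [1, 1, iC, oC]
    input.linspace(1.f);
    weights.assign(0.5f);

    sd::ops::pointwise_conv2d op;
    auto results = op.evaluate({&input, &weights}, {}, {1 /*INT_ARG(0)=1 -> NHWC*/});
    results->at(0)->printShapeInfo("pointwise_conv2d output");         // expected [2, 4, 4, 8]
    delete results;                                                    // caller owns the result set
}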
@@ -71,7 +71,7 @@ CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {

DECLARE_SHAPE_FN(pointwise_conv2d) {

Nd4jLong* inputShapeInfo   = inputShape->at(0);                                 // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
Nd4jLong* weightsShapeInfo = inputShape->at(1);                                 // [1, 1, iC, oC] always
Nd4jLong* biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;   // [oC]
@@ -89,18 +89,18 @@ DECLARE_SHAPE_FN(pointwise_conv2d) {
    indIOioC = 1;

const int bS = inputShapeInfo[1];             // batch size
const int iC = inputShapeInfo[indIOioC+1];    // input channels
const int oC = weightsShapeInfo[indWoC+1];    // output channels

std::string expectedWeightsShape = ShapeUtils::shapeAsString({1, 1, iC, oC});
REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weightsShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
std::vector<Nd4jLong> expectedWeightsShape = {1, 1, iC, oC};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
if (biasShapeInfo)
    REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));

auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.getWorkspace());

// do not forget to put oC instead of iC in outputShapeInfo
outputShapeInfo[indIOioC + 1] = oC;

shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo));
@ -73,11 +73,11 @@ CUSTOM_OP_IMPL(sconv2d, 2, 1, false, 0, 9) {
|
|||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||
mC = weightsDepth->sizeAt(indWmC); // channels multiplier
|
||||
|
||||
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
||||
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDepth), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
|
||||
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
|
||||
if(weightsPoint) {
|
||||
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPoint), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
|
||||
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
|
||||
}
|
||||
if (bias)
|
||||
REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf());
|
||||
|
@ -151,11 +151,11 @@ DECLARE_SHAPE_FN(sconv2d) {
|
|||
const int mC = weightsDShapeInfo[indWmC+1]; // channel multiplier
|
||||
const int oC = weightsPShapeInfo ? weightsPShapeInfo[indWmC+1] : iC*mC; // output channels (oC or iC*mC)
|
||||
|
||||
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
||||
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDShapeInfo), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
|
||||
if(weightsPShapeInfo) {
|
||||
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPShapeInfo), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
|
||||
}
|
||||
if (biasShapeInfo)
|
||||
REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "SCONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||
|
@ -250,13 +250,13 @@ CUSTOM_OP_IMPL(sconv2d_bp, 3, 2, false, 0, 9) {
|
|||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWmC, indWkH, indOoH);
|
||||
mC = weightsDepth->sizeAt(indWmC); // channels multiplier
|
||||
|
||||
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
||||
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDepth), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(gradWD), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
|
||||
REQUIRE_TRUE(weightsDepth->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of weightsDepth array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDepth).c_str());
|
||||
REQUIRE_TRUE(gradWD->isSameShape(expectedWeightsDShape), 0, " SCONV2D_BP OP: wrong shape of gradWD array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(gradWD).c_str());
|
||||
if(weightsPoint) {
|
||||
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPoint), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
|
||||
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(gradWP), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
|
||||
REQUIRE_TRUE(weightsPoint->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of weightsPoint array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPoint).c_str());
|
||||
REQUIRE_TRUE(gradWP->isSameShape(expectedWeightsPShape), 0, " SCONV2D_BP OP: wrong shape of gradWP array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(gradWP).c_str());
|
||||
}
|
||||
if (bias) {
|
||||
REQUIRE_TRUE(oC == bias->lengthOf(), 0, " SCONV2D_BP OP: length of bias array must be equal to outChannels, but got %i instead", bias->lengthOf());
|
||||
|
@ -354,13 +354,13 @@ DECLARE_SHAPE_FN(sconv2d_bp) {
|
|||
int trueoH, trueoW; // true output height, width
|
||||
ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
||||
|
||||
std::string expectedGradOShapeInfo = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
REQUIRE_TRUE(expectedGradOShapeInfo == ShapeUtils::shapeAsString(gradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShapeInfo.c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||
std::string expectedWeightsDShape = ShapeUtils::shapeAsString({kH, kW, iC, mC});
|
||||
REQUIRE_TRUE(expectedWeightsDShape == ShapeUtils::shapeAsString(weightsDShapeInfo), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", expectedWeightsDShape.c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedGradOShapeInfo = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(gradOShapeInfo, expectedGradOShapeInfo), 0, "SCONV2D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShapeInfo).c_str(), ShapeUtils::shapeAsString(gradOShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsDShape = {kH, kW, iC, mC};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsDShapeInfo, expectedWeightsDShape), 0, "SCONV2D_BP OP: wrong shape of depth weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsDShape).c_str(), ShapeUtils::shapeAsString(weightsDShapeInfo).c_str());
|
||||
if(weightsPShapeInfo) {
|
||||
std::string expectedWeightsPShape = ShapeUtils::shapeAsString({1, 1, iC*mC, oC});
|
||||
REQUIRE_TRUE(expectedWeightsPShape == ShapeUtils::shapeAsString(weightsPShapeInfo), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", expectedWeightsPShape.c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
|
||||
std::vector<Nd4jLong> expectedWeightsPShape = {1, 1, iC*mC, oC};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsPShapeInfo, expectedWeightsPShape), 0, "SCONV2D_BP OP: wrong shape of point array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsPShape).c_str(), ShapeUtils::shapeAsString(weightsPShapeInfo).c_str());
|
||||
}
|
||||
if (biasShapeInfo)
|
||||
REQUIRE_TRUE((biasShapeInfo[0] == 1 || biasShapeInfo[0] == 2) && oC == shape::length(biasShapeInfo), 0, "SCONV2D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, biasShapeInfo[0], shape::length(biasShapeInfo));
|
||||
|
|
|
@ -168,11 +168,10 @@ CUSTOM_OP_IMPL(avgpool2d_bp, 2, 1, false, 0, 10) {
|
|||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCHW) {
|
||||
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
|
|
|
@@ -57,8 +57,8 @@ CUSTOM_OP_IMPL(avgpool3dnew, 1, 1, false, 0, 14) {
int indIOioC, indIOioD, indWoC, indWiC, indWkD;       // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "AVGPOOL3DNEW OP: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());

if(!isNCDHW) {
    input = new NDArray(input->permute({0, 4, 1, 2, 3}));                       // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
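The expected output shape here is assembled with ShapeUtils::composeShapeUsingDimsAndIdx, whose argument list carries the dimension values in its first half and the target positions of those values in its second half, so the same call covers both NCDHW and NDHWC layouts. A small stand-alone sketch of that convention, with the semantics inferred from how it is used in these ops and a hypothetical helper name:

#include <cstdint>
#include <vector>

// Hypothetical re-implementation of the composeShapeUsingDimsAndIdx convention:
// the first half of dimsAndIdx holds dimension values, the second half holds the
// index each value should occupy in the resulting shape.
std::vector<int64_t> composeShapeSketch(const std::vector<int64_t>& dimsAndIdx) {
    const size_t n = dimsAndIdx.size() / 2;
    std::vector<int64_t> shape(n);
    for (size_t i = 0; i < n; ++i)
        shape[dimsAndIdx[n + i]] = dimsAndIdx[i];
    return shape;
}

// NCDHW (indIOioC = 1, indIOioD = 2):
//   composeShapeSketch({bS, iC, oD, oH, oW,  0, 1, 2, 3, 4}) -> {bS, iC, oD, oH, oW}
// NDHWC (indIOioC = 4, indIOioD = 1):
//   composeShapeSketch({bS, iC, oD, oH, oW,  0, 4, 1, 2, 3}) -> {bS, oD, oH, oW, iC}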
@ -174,10 +174,10 @@ CUSTOM_OP_IMPL(avgpool3dnew_bp, 2, 1, false, 0, 14) {
|
|||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "AVGPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
|
|
|
@ -170,10 +170,10 @@ CUSTOM_OP_IMPL(maxpool2d_bp, 2, 1, false, 0, 10) {
|
|||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCHW) {
|
||||
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
|
|
|
@ -57,8 +57,8 @@ CUSTOM_OP_IMPL(maxpool3dnew, 1, 1, false, 0, 14) {
|
|||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
std::vector<Nd4jLong> expectedOutputShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(output->isSameShape(expectedOutputShape), 0, "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedOutputShape).c_str(), ShapeUtils::shapeAsString(output).c_str());
|
||||
// REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
|
||||
// REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
|
||||
|
||||
|
@ -176,10 +176,10 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
|
|||
int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCDHW) {
|
||||
input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
|
||||
|
|
|
@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(pnormpool2d_bp, 2, 1, false, 1, 10) {
|
|||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1}));
|
||||
REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
|
||||
std::vector<Nd4jLong> expectedGradIShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iH,iW, 0,indIOioC,indIiH,indIiH+1});
|
||||
REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "PNORMPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
|
||||
REQUIRE_TRUE(gradI->isSameShape(expectedGradIShape), 0, "PNORMPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradIShape).c_str(), ShapeUtils::shapeAsString(gradI).c_str());
|
||||
|
||||
if(!isNCHW) {
|
||||
input = new NDArray(input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
||||
|
|
|
@@ -41,7 +41,7 @@ CUSTOM_OP_IMPL(parallel_stack, -1, 1, false, 0, 0) {
    inArrs[i] = INPUT_VARIABLE(i);

const int dim = 0;
helpers::stack(block.launchContext(), inArrs, output, dim);
helpers::stack(block.launchContext(), inArrs, *output, dim);

return Status::OK();
}
@@ -70,7 +70,7 @@ namespace sd {
        empty = true;
        //Don't break to perform input validation on other dims
    }

    indices[2*e] = start;
    indices[2*e+1] = start + size;
}
@@ -80,8 +80,29 @@ namespace sd {
    return Status::OK();
}

auto sub = (*input)(indices, true);
output->assign(sub);
Nd4jLong* subArrShapeInfo = nullptr;
ALLOCATE(subArrShapeInfo, block.getWorkspace(), shape::shapeInfoLength(input->rankOf()), Nd4jLong);

Nd4jLong offset;

shape::calcSubArrShapeInfoAndOffset(indices.data(), input->getShapeInfo(), subArrShapeInfo, offset, true);

auto subArrShapeInfoPack = ConstantShapeHelper::getInstance()->bufferForShapeInfo(subArrShapeInfo);

NDArray::prepareSpecialUse({output}, {input});

NativeOpExecutioner::execTransformAny(block.launchContext(), sd::transform::Assign,
                                      input->bufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.primary()),
                                      input->specialBufferWithOffset(offset), reinterpret_cast<Nd4jLong *>(subArrShapeInfoPack.special()),
                                      output->buffer(), output->shapeInfo(), output->specialBuffer(), output->specialShapeInfo(),
                                      nullptr, nullptr, nullptr, true);

NDArray::registerSpecialUse({output}, {input});

RELEASE(subArrShapeInfo, block.getWorkspace());

// auto sub = (*input)(indices, true);
// output->assign(sub);

STORE_RESULT(output);
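Rather than materialising a sub-array view with (*input)(indices, true) and assigning it to the output, the slice op now computes the sub-array's shapeInfo and offset directly and runs the Assign transform straight from the offset input buffer into the dense output. A rough, stand-alone sketch of the copy this amounts to on the CPU side (hypothetical helper, single dtype, no parallelism):

#include <cstdint>
#include <vector>

// Sketch: copy the sub-array described by (shape, srcStrides, srcOffset) out of 'src'
// into a contiguous c-order destination buffer, which is what the Assign transform
// effectively does for the slice output.
void copySubArraySketch(const float* src, int64_t srcOffset,
                        const std::vector<int64_t>& shape,
                        const std::vector<int64_t>& srcStrides,
                        float* dst) {
    const int rank = static_cast<int>(shape.size());
    std::vector<int64_t> idx(rank, 0);
    int64_t dstPos = 0;
    while (true) {
        int64_t srcPos = srcOffset;
        for (int d = 0; d < rank; ++d) srcPos += idx[d] * srcStrides[d];
        dst[dstPos++] = src[srcPos];
        int d = rank - 1;                                 // odometer-style index increment
        while (d >= 0 && ++idx[d] == shape[d]) { idx[d] = 0; --d; }
        if (d < 0) break;
    }
}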
@@ -116,7 +137,7 @@ namespace sd {

REQUIRE_TRUE(begin.size() == x_rank, 0, "Begin array should have length of [%i] but got [%i] instead", x_rank, begin.size());
REQUIRE_TRUE(sz.size() == x_rank, 0, "Size array should have length of [%i] but got [%i] instead", x_rank, sz.size());

std::vector<Nd4jLong> shape;
auto empty = false;
for (int e = 0; e < x_rank; e++) {
@@ -186,12 +207,12 @@ namespace sd {
        size = input->sizeAt(e) - start;
    }
    REQUIRE_TRUE(size > 0, 0, "Slice: interval for dimension %i is less then 1", e);

    indices[2*e] = start;
    indices[2*e + 1] = start + size;
}
auto sub = (*output)(indices, true);
sub.assign(epsNext);

return Status::OK();
}
@@ -50,7 +50,7 @@ CUSTOM_OP_IMPL(stack, -1, 1, false, 0, 0) {
for(int i = 0; i < block.width(); ++i)
    inArrs[i] = INPUT_VARIABLE(i);

helpers::stack(block.launchContext(), inArrs, output, dim);
helpers::stack(block.launchContext(), inArrs, *output, dim);

return Status::OK();
}
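Both stack and parallel_stack now hand the output array to helpers::stack by reference (*output) instead of by pointer. For orientation, a hedged, test-style sketch of what the op computes (the evaluate/ResultSet pattern follows the style used elsewhere in this changeset; the NDArrayFactory header path is assumed):

#include <ops/declarable/CustomOperations.h>
// plus the NDArrayFactory header, whose path is assumed here

// Sketch: stacking three length-4 vectors along dimension 0 yields a 3x4 matrix.
void runStackSketch() {
    auto a = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
    auto b = NDArrayFactory::create<float>('c', {4}, {5.f, 6.f, 7.f, 8.f});
    auto c = NDArrayFactory::create<float>('c', {4}, {9.f, 10.f, 11.f, 12.f});

    sd::ops::stack op;
    auto results = op.evaluate({&a, &b, &c}, {}, {0 /*dim*/});
    results->at(0)->printIndexedBuffer("stacked [3,4]");
    delete results;
}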
@@ -16,125 +16,114 @@
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_unstack)

#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
#include <ops/declarable/helpers/stack.h>
namespace sd {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
namespace ops {
|
||||
|
||||
auto dim = INT_ARG(0);
|
||||
if (dim < 0)
|
||||
dim += input->rankOf();
|
||||
CUSTOM_OP_IMPL(unstack, 1, -1, false, 0, 1) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
|
||||
auto dim = INT_ARG(0);
|
||||
if (dim < 0)
|
||||
dim += input->rankOf();
|
||||
|
||||
|
||||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
|
||||
REQUIRE_TRUE(dim < input->rankOf(), 0, "Unstack dimension should be lower then rank of input %i, but got dimension=%i !", input->rankOf(), dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "Unstack dimension should be non-negative value, but got %i !", dim);
|
||||
|
||||
if(input->isEmpty())
|
||||
return Status::OK();
|
||||
if(input->isEmpty())
|
||||
return Status::OK();
|
||||
|
||||
std::vector<int> dims;
|
||||
for (int e = 0; e < input->rankOf(); e++)
|
||||
if (e != dim)
|
||||
dims.emplace_back(e);
|
||||
if (dims.size() == 0 && input->rankOf() == 1) { // split vector into lenthOf scalars
|
||||
for (Nd4jLong e = 0; e < input->lengthOf(); e++) {
|
||||
auto outE = OUTPUT_VARIABLE(e);
|
||||
outE->assign(input->e(e));
|
||||
}
|
||||
}
|
||||
std::vector<NDArray*> outArrs(input->sizeAt(dim));
|
||||
for(uint i = 0; i < outArrs.size(); ++i)
|
||||
outArrs[i] = OUTPUT_VARIABLE(i);
|
||||
|
||||
auto tads = input->allTensorsAlongDimension(dims);
|
||||
//nd4j_printf("Tad size: %d\n",tads.size());
|
||||
for (int e = 0; e < tads.size(); e++) {
|
||||
//nd4j_printf("Calling assign at index %d\n",e);
|
||||
auto outE = OUTPUT_VARIABLE(e);
|
||||
auto tadAtE = tads.at(e);
|
||||
helpers::unstack(block.launchContext(), *input, outArrs, dim);
|
||||
|
||||
outE->assign(tadAtE);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
this->storeResult(block, e, *outE);
|
||||
}
|
||||
DECLARE_SYN(unpack, unstack);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SHAPE_FN(unstack) {
|
||||
auto inShapeInfo = inputShape->at(0);
|
||||
|
||||
DECLARE_SYN(unpack, unstack);
|
||||
auto dim = INT_ARG(0);
|
||||
if (dim < 0)
|
||||
dim += shape::rank(inShapeInfo);
|
||||
|
||||
DECLARE_SHAPE_FN(unstack) {
|
||||
auto inShape = inputShape->at(0);
|
||||
REQUIRE_TRUE(dim < inShapeInfo[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShapeInfo[0], dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
|
||||
|
||||
auto dim = INT_ARG(0);
|
||||
if (dim < 0)
|
||||
dim += shape::rank(inShape);
|
||||
if(ArrayOptions::arrayType(inShapeInfo) == ArrayType::EMPTY) {
|
||||
|
||||
REQUIRE_TRUE(dim < inShape[0], 0, "UNSTACK op: dimension should be lower then rank of input %i, but got dimension=%i !", inShape[0], dim);
|
||||
REQUIRE_TRUE(dim >= 0, 0, "UNSTACK op: dimension should be non-negative value, but got %i !", dim);
|
||||
if(shape::shapeOf(inShapeInfo)[dim] == 0)
|
||||
return SHAPELIST();
|
||||
|
||||
if(ArrayOptions::arrayType(inShape) == ArrayType::EMPTY) {
|
||||
if(shape::shapeOf(inShape)[dim] == 0)
|
||||
return SHAPELIST();
|
||||
const Nd4jLong numTads = shape::shapeOf(inShape)[dim];
|
||||
std::vector<Nd4jLong> outShape;
|
||||
for(uint i = 0; i < shape::rank(inShape); ++i)
|
||||
if(i != dim)
|
||||
outShape.push_back(shape::shapeOf(inShape)[i]);
|
||||
const Nd4jLong numTads = shape::shapeOf(inShapeInfo)[dim];
|
||||
std::vector<Nd4jLong> outShape;
|
||||
for(uint i = 0; i < shape::rank(inShapeInfo); ++i)
|
||||
if(i != dim)
|
||||
outShape.push_back(shape::shapeOf(inShapeInfo)[i]);
|
||||
|
||||
auto result = SHAPELIST();
|
||||
for(uint i = 0; i < numTads; ++i)
|
||||
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), outShape));
|
||||
return result;
|
||||
}
|
||||
auto result = SHAPELIST();
|
||||
for(uint i = 0; i < numTads; ++i)
|
||||
result->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), outShape));
|
||||
|
||||
std::vector<int> dims;
|
||||
for (int e = 0; e < shape::rank(inShape); e++)
|
||||
if (e != dim)
|
||||
dims.emplace_back(e);
|
||||
if (dims.size() == 0 && shape::rank(inShape) == 1) { // split vector into lenthOf scalars
|
||||
//
|
||||
auto result = SHAPELIST();
|
||||
for (Nd4jLong e = 0; e < shape::length(inShape); e++)
|
||||
result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShape)));
|
||||
return result;
|
||||
}
|
||||
|
||||
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(inShape, dims);
|
||||
auto numTads = tadPack.numberOfTads();
|
||||
|
||||
std::vector<Nd4jLong> shape(shape::rank(tadPack.primaryShapeInfo()));
|
||||
for (int e = 0; e < shape::rank(tadPack.primaryShapeInfo()); e++)
|
||||
shape[e] = shape::shapeOf(tadPack.primaryShapeInfo())[e];
|
||||
|
||||
// remove leading and trailing 1
|
||||
if (inShape[0] == 2 && shape.size() == 2) {
|
||||
if (shape[0] == 1) {
|
||||
shape.erase(shape.begin());
|
||||
} else if (shape[1] == 1) {
|
||||
shape.erase(shape.end());
|
||||
}
|
||||
}
|
||||
|
||||
auto result = SHAPELIST();
|
||||
for (int e = 0; e < numTads; e++) {
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape);
|
||||
result->push_back(newShape);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
DECLARE_TYPES(unstack) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<int> dims = ShapeUtils::evalDimsToExclude(inShapeInfo[0], {dim});
|
||||
|
||||
if (dims.size() == 0 && shape::rank(inShapeInfo) == 1) { // split vector into lenthOf scalars
|
||||
|
||||
auto result = SHAPELIST();
|
||||
for (Nd4jLong e = 0; e < shape::length(inShapeInfo); e++)
|
||||
result->push_back(ConstantShapeHelper::getInstance()->scalarShapeInfo(ArrayOptions::dataType(inShapeInfo)));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<Nd4jLong> subArrShape(shape::rank(inShapeInfo) - 1);
|
||||
|
||||
for(uint j = 0, i = 0; i < shape::rank(inShapeInfo); i++)
|
||||
if(i != dim)
|
||||
subArrShape[j++] = shape::shapeOf(inShapeInfo)[i];
|
||||
|
||||
// remove leading and trailing 1
|
||||
if (inShapeInfo[0] == 2 && subArrShape.size() == 2) {
|
||||
|
||||
if (subArrShape[0] == 1)
|
||||
subArrShape.erase(subArrShape.begin());
|
||||
else if (subArrShape[1] == 1)
|
||||
subArrShape.erase(subArrShape.end());
|
||||
}
|
||||
|
||||
auto result = SHAPELIST();
|
||||
for (int e = 0; e < shape::shapeOf(inShapeInfo)[dim]; e++) {
|
||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), subArrShape);
|
||||
result->push_back(newShape);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
DECLARE_TYPES(unstack) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
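The rewritten unstack above delegates the per-slice copies to helpers::unstack, and its shape function builds each output shape by dropping the unstacked dimension (with the rank-1 "vector of scalars" case handled separately). A hedged, test-style sketch of the op's contract (same assumed evaluate pattern and factory header as in the earlier sketches):

#include <ops/declarable/CustomOperations.h>
// plus the NDArrayFactory header, whose path is assumed here

// Sketch: unstacking a [3, 4] matrix along dimension 0 produces three [4] vectors;
// unstacking a rank-1 array instead produces one scalar per element.
void runUnstackSketch() {
    auto x = NDArrayFactory::create<float>('c', {3, 4});
    x.linspace(1.f);

    sd::ops::unstack op;
    auto results = op.evaluate({&x}, {}, {0 /*dim*/});
    for (int i = 0; i < results->size(); ++i)
        results->at(i)->printIndexedBuffer("row");   // three vectors of length 4
    delete results;
}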
@@ -54,15 +54,15 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
        maxTimeStep = INPUT_VARIABLE(9);
        break;
}

auto hFW      = OUTPUT_VARIABLE(0);   // cell outputs for forward RNN  [time x bS x numUnitsFW] or [bS x time x numUnitsFW], shape depends on timeMajor parameter
auto hBW      = OUTPUT_VARIABLE(1);   // cell outputs for backward RNN [time x bS x numUnitsBW] or [bS x time x numUnitsBW], shape depends on timeMajor parameter
auto hFWFinal = OUTPUT_VARIABLE(2);   // final cell out for forward RNN  [bS x numUnitsFW]
auto hBWFinal = OUTPUT_VARIABLE(3);   // final cell out for backward RNN [bS x numUnitsBW]

REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());

const int inRank = x->rankOf();
const int time   = timeMajor ? x->sizeAt(0) : x->sizeAt(1);
@ -70,16 +70,26 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
|||
const int numUnitsFW = WxFW->sizeAt(1);
|
||||
const int numUnitsBW = WxBW->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
if(h0BW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
|
||||
std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
|
||||
std::vector<Nd4jLong> expectedbFWshape = {2*numUnitsFW};
|
||||
std::vector<Nd4jLong> expectedbBWshape = {2*numUnitsBW};
|
||||
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape) , 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW) {
|
||||
std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
|
||||
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
}
|
||||
if(h0BW) {
|
||||
std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
|
||||
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
}
|
||||
if(maxTimeStep) {
|
||||
std::vector<Nd4jLong> expectedmaxTimeStepshape = {bS};
|
||||
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
}
|
||||
|
||||
|
||||
// forward steps
|
||||
|
@@ -101,19 +111,19 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
REQUIRE_TRUE(resultsIn->status() == ND4J_STATUS_OK, 0, "dynamic_bidirectional_rnn: there is a problem with reverse on the sequence.");
auto revInput = resultsIn->at(0);

// backward steps
auto resultsBW = dynamicRnn.evaluate({revInput, WxBW, WhBW, bBW, h0BW, maxTimeStep}, {timeMajor});
auto hBWtemp = resultsBW->at(0);        // [time x bS x numUnitsBW] or [bS x time x numUnitsBW]
hBWFinal->assign(resultsBW->at(1));

// reverse hBWtemp
auto resultsOut = timeMajor ? reverse.evaluate({hBWtemp, seqLen}, {0, 1}) : reverse.evaluate({hBWtemp, seqLen}, {1, 0});
hBW->assign(resultsOut->at(0));

delete resultsOut;
delete resultsBW;
delete resultsIn;
delete resultsFW;

if(seqLen != maxTimeStep)
    delete seqLen;
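The forward and backward passes above are built by invoking other declarable ops (dynamic_rnn, the reverse op) through evaluate, which returns a heap-allocated result set owned by the caller, hence the series of delete statements once the intermediate outputs have been copied into this op's output variables. A compressed sketch of that ownership rule (the reverse_sequence op name and the unique_ptr wrapper are only illustrations of the same pattern, not what the code above does):

#include <memory>
#include <type_traits>
#include <ops/declarable/CustomOperations.h>

// Sketch of the call / use / delete pattern: evaluate() hands back an owning pointer,
// so every call site must release it exactly once.
void runReverseSketch(NDArray& sequence, NDArray& seqLen) {
    sd::ops::reverse_sequence reverse;                                   // assumed op class for the 'reverse' object above
    auto raw = reverse.evaluate({&sequence, &seqLen}, {}, {0, 1});       // caller owns the result set
    std::unique_ptr<std::remove_pointer_t<decltype(raw)>> results(raw);  // illustration of the ownership the explicit deletes enforce
    if (results->status() == ND4J_STATUS_OK)
        results->at(0)->printShapeInfo("reversed sequence");
}   // result set released here, mirroring the explicit deletes in dynamic_bidirectional_rnn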
@ -128,7 +138,7 @@ CUSTOM_OP_IMPL(dynamic_bidirectional_rnn, 7, 4, false, 0, 0) {
|
|||
}
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
||||
DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0); // input [time x bS x inSize] or [bS x time x inSize], shape depends on timeMajor parameter
|
||||
auto WxFW = INPUT_VARIABLE(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
|
||||
|
@ -143,7 +153,7 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
NDArray* maxTimeStep = nullptr; // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep
|
||||
|
||||
const int timeMajor = block.getIArguments()->size() > 0 ? INT_ARG(0) : 0; // if true then [time, bS, ...], else [bS, time, ...]
|
||||
|
||||
|
||||
switch(block.width()) {
|
||||
case 8:
|
||||
maxTimeStep = INPUT_VARIABLE(7);
|
||||
|
@ -160,8 +170,8 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
}
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 3, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
|
||||
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
|
||||
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
|
||||
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
|
||||
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
|
||||
|
||||
const int inRank = x->rankOf();
|
||||
const int time = timeMajor ? x->sizeAt(0) : x->sizeAt(1);
|
||||
|
@ -169,16 +179,28 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
const int numUnitsFW = WxFW->sizeAt(1);
|
||||
const int numUnitsBW = WxBW->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
if(h0BW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
|
||||
std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
|
||||
std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
|
||||
std::vector<Nd4jLong> expectedbFWshape = {2*numUnitsFW};
|
||||
std::vector<Nd4jLong> expectedbBWshape = {2*numUnitsBW};
|
||||
|
||||
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW) {
|
||||
std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
|
||||
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
}
|
||||
if(h0BW) {
|
||||
std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
|
||||
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
}
|
||||
if(maxTimeStep) {
|
||||
std::vector<Nd4jLong> expectedmaxTimeStepshape = {bS};
|
||||
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepshape), 0, "DYNAMIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
}
|
||||
|
||||
// evaluate output shapeInfos
|
||||
Nd4jLong *hFWShapeInfo(nullptr), *hBWShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr);
|
||||
|
@ -187,13 +209,13 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
ALLOCATE(hFWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong);
|
||||
ALLOCATE(hBWFinalPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank-1), Nd4jLong);
|
||||
|
||||
hFWShapeInfo[0] = hBWShapeInfo[0] = inRank;
|
||||
hFWShapeInfo[0] = hBWShapeInfo[0] = inRank;
|
||||
hFWShapeInfo[1] = hBWShapeInfo[1] = timeMajor ? time : bS;
|
||||
hFWShapeInfo[2] = hBWShapeInfo[2] = timeMajor ? bS : time;
|
||||
hFWShapeInfo[3] = numUnitsFW;
|
||||
hBWShapeInfo[3] = numUnitsBW;
|
||||
hFWFinalPrevShapeInfo[0] = hBWFinalPrevShapeInfo[0] = inRank-1;
|
||||
hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS;
|
||||
hFWFinalPrevShapeInfo[1] = hBWFinalPrevShapeInfo[1] = bS;
|
||||
hFWFinalPrevShapeInfo[2] = numUnitsFW;
|
||||
hBWFinalPrevShapeInfo[2] = numUnitsBW;
|
||||
|
||||
|
@ -201,9 +223,9 @@ DECLARE_SHAPE_FN(dynamic_bidirectional_rnn) {
|
|||
ShapeUtils::updateStridesAndType(hBWShapeInfo, x->getShapeInfo(), x->ordering());
|
||||
ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, x->getShapeInfo(), x->ordering());
|
||||
ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, x->getShapeInfo(), x->ordering());
|
||||
|
||||
|
||||
return SHAPELIST(CONSTANT(hFWShapeInfo), CONSTANT(hBWShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -60,12 +60,18 @@ CUSTOM_OP_IMPL(dynamic_rnn, 4, 2, false, 0, 0) {
|
|||
const int bS = timeMajor ? x->sizeAt(1) : x->sizeAt(0);
|
||||
const int numUnits = Wx->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(Wh) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(b) == ShapeUtils::shapeAsString({2*numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
if(h0)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
|
||||
std::vector<Nd4jLong> expectedBShape = {2*numUnits};
|
||||
REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
if(h0) {
|
||||
std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
|
||||
REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
}
|
||||
if(maxTimeStep) {
|
||||
std::vector<Nd4jLong> expectedmaxTimeStepShape = {bS};
|
||||
REQUIRE_TRUE(maxTimeStep->isSameShape(expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
}
if(timeMajor == false) {
|
||||
x = new NDArray(x->permute({1, 0, 2})); // [bS x time x inSize] -> [time x bS x inSize]
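As an aside - and not part of this commit - the hunk above swaps string-based shape checks for direct comparisons on the dimension lists. A minimal standalone C++ sketch of the two styles (all names below are invented for illustration; none of this is libnd4j API):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;   // stand-in for an Nd4jLong dimension list

// old style: format both shapes as "[4, 4]"-like strings, then compare the strings
static std::string shapeToString(const Shape& s) {
    std::string out = "[";
    for (size_t i = 0; i < s.size(); ++i)
        out += std::to_string(s[i]) + (i + 1 < s.size() ? ", " : "");
    return out + "]";
}
static bool sameShapeViaStrings(const Shape& a, const Shape& b) {
    return shapeToString(a) == shapeToString(b);          // allocates and formats even when the check passes
}

// new style: compare the dimensions themselves
static bool sameShape(const Shape& a, const Shape& b) {
    return a == b;                                        // element-wise comparison, no allocation
}

int main() {
    const int64_t numUnits = 4;
    Shape Wh       = {numUnits, numUnits};
    Shape expected = {numUnits, numUnits};
    std::cout << sameShapeViaStrings(Wh, expected) << ' ' << sameShape(Wh, expected) << '\n';   // prints: 1 1
    return 0;
}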
@ -127,12 +133,19 @@ DECLARE_SHAPE_FN(dynamic_rnn) {
|
|||
const int bS = timeMajor ? xShapeInfo[2] : xShapeInfo[1];
|
||||
const int numUnits = WxShapeInfo[2];
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhShapeInfo) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bShapeInfo) == ShapeUtils::shapeAsString({2*numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
if(h0ShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0ShapeInfo) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
if(maxTimeStepShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
|
||||
std::vector<Nd4jLong> expectedBShape = {2*numUnits};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "DYNAMIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedBShape), 0, "DYNAMIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
if(h0ShapeInfo) {
|
||||
std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "DYNAMIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
}
|
||||
if(maxTimeStepShapeInfo) {
|
||||
std::vector<Nd4jLong> expectedmaxTimeStepShape = {bS};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, expectedmaxTimeStepShape), 0, "DYNAMIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedmaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
|
||||
}
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr);
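The shape function above fills the output dimensions and then delegates stride and type handling to ShapeUtils::updateStridesAndType. Purely as an illustration of what row-major ('c'-order) strides look like for a [time x bS x numUnits] result - a simplified sketch, not the libnd4j shapeInfo layout:

#include <cstdint>
#include <iostream>
#include <vector>

// compute c-order (row-major) strides for a dimension list; illustrative only
static std::vector<int64_t> cOrderStrides(const std::vector<int64_t>& dims) {
    std::vector<int64_t> strides(dims.size(), 1);
    for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
    return strides;
}

int main() {
    const int64_t time = 5, bS = 2, numUnits = 3;
    for (int64_t s : cOrderStrides({time, bS, numUnits}))
        std::cout << s << ' ';                            // prints: 6 3 1
    std::cout << '\n';
    return 0;
}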
@ -32,35 +32,31 @@ namespace ops {
|
|||
CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
|
||||
auto x = INPUT_VARIABLE(0); // input [time x bS x iS]
|
||||
auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU]
auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU]
|
||||
auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU]
|
||||
auto b = INPUT_VARIABLE(4); // biases, [3*nU]
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step
const int rank = x->rankOf();              // = 3
|
||||
const int time = x->sizeAt(0);
|
||||
const int bS = x->sizeAt(1);
|
||||
const int iS = x->sizeAt(2);
|
||||
const int nU = h0->sizeAt(1);
const std::string h0Shape = ShapeUtils::shapeAsString(h0);
|
||||
const std::string h0CorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
const std::string wxShape = ShapeUtils::shapeAsString(Wx);
|
||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({iS, 3*nU});
|
||||
const std::string whShape = ShapeUtils::shapeAsString(Wh);
|
||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({nU, 3*nU});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*nU});
REQUIRE_TRUE(h0Shape == h0CorrectShape, 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", h0CorrectShape.c_str(), h0Shape.c_str());
|
||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
const std::vector<Nd4jLong> h0CorrectShape = {bS, nU};
|
||||
const std::vector<Nd4jLong> wxCorrectShape = {iS, 3*nU};
|
||||
const std::vector<Nd4jLong> whCorrectShape = {nU, 3*nU};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {3*nU};
REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h);
return Status::OK();
|
||||
}
@ -72,7 +68,7 @@ CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) {
|
|||
}
DECLARE_SHAPE_FN(gru) {
|
||||
const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
|
||||
const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0
|
||||
const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits]
@ -85,34 +81,30 @@ DECLARE_SHAPE_FN(gru) {
|
|||
const auto inSize = xShapeInfo[3];
|
||||
const auto numUnits = h0ShapeInfo[2];
const std::string h0Shape = ShapeUtils::shapeAsString(h0ShapeInfo);
|
||||
const std::string h0CorrectShape = ShapeUtils::shapeAsString({bS, numUnits});
|
||||
const std::string wxShape = ShapeUtils::shapeAsString(WxShapeInfo);
|
||||
const std::string wxCorrectShape = ShapeUtils::shapeAsString({inSize, 3*numUnits});
|
||||
const std::string whShape = ShapeUtils::shapeAsString(WhShapeInfo);
|
||||
const std::string whCorrectShape = ShapeUtils::shapeAsString({numUnits, 3*numUnits});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({3*numUnits});
REQUIRE_TRUE(h0Shape == h0CorrectShape, 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", h0CorrectShape.c_str(), h0Shape.c_str());
|
||||
REQUIRE_TRUE(wxShape == wxCorrectShape, 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", wxCorrectShape.c_str(), wxShape.c_str());
|
||||
REQUIRE_TRUE(whShape == whCorrectShape, 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", whCorrectShape.c_str(), whShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
const std::vector<Nd4jLong> h0CorrectShape = {bS, numUnits};
|
||||
const std::vector<Nd4jLong> wxCorrectShape = {inSize, 3*numUnits};
|
||||
const std::vector<Nd4jLong> whCorrectShape = {numUnits, 3*numUnits};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {3*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
// evaluate output shapeInfo
|
||||
Nd4jLong *hShapeInfo(nullptr);
|
||||
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
hShapeInfo[0] = rank;
|
||||
hShapeInfo[1] = time;
|
||||
hShapeInfo[2] = bS;
|
||||
hShapeInfo[3] = numUnits;
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo));
return SHAPELIST(hShapeInfo);
|
||||
}
|
||||
}
}
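For reference, the expected GRU parameter shapes that the checks above encode, collected into one small standalone helper (a sketch with invented names, not code from this repository):

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// expected GRU parameter shapes for batch size bS, input size iS and nU units,
// mirroring the h0 / Wx / Wh / b checks in the gru op above (sketch only)
struct GruExpectedShapes {
    Shape h0, Wx, Wh, b;
};
static GruExpectedShapes gruExpected(int64_t bS, int64_t iS, int64_t nU) {
    return { {bS, nU}, {iS, 3 * nU}, {nU, 3 * nU}, {3 * nU} };
}

int main() {
    auto e = gruExpected(/*bS=*/2, /*iS=*/4, /*nU=*/3);
    assert((e.Wx == Shape{4, 9}));
    assert((e.b  == Shape{9}));
    return 0;
}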
@ -145,30 +145,21 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) {
REQUIRE_TRUE(x->rankOf() == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", x->rankOf());
const std::string hiShape = ShapeUtils::shapeAsString(hi);
|
||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
const std::string wShape = ShapeUtils::shapeAsString(W);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||
const std::string wcShape = ShapeUtils::shapeAsString(Wc);
|
||||
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||
const std::string bcShape = ShapeUtils::shapeAsString(bc);
|
||||
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdr);
|
||||
const std::string dLduShape = ShapeUtils::shapeAsString(dLdu);
|
||||
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdc);
|
||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdh);
|
||||
const std::vector<Nd4jLong> hiCorrectShape = {bS, nU};
|
||||
const std::vector<Nd4jLong> wCorrectShape = {iS+nU, 2*nU};
|
||||
const std::vector<Nd4jLong> wcCorrectShape = {iS+nU, nU};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {2*nU};
|
||||
const std::vector<Nd4jLong> bcCorrectShape = {nU};
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||
REQUIRE_TRUE(hi->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hi).c_str());
|
||||
REQUIRE_TRUE(W->isSameShape(wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(W).c_str());
|
||||
REQUIRE_TRUE(Wc->isSameShape(wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(bc->isSameShape(bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bc).c_str());
|
||||
REQUIRE_TRUE(dLdr->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdr).c_str());
|
||||
REQUIRE_TRUE(dLdu->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdu).c_str());
|
||||
REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str());
|
||||
REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str());
helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc);
@ -210,30 +201,21 @@ DECLARE_SHAPE_FN(gruCell_bp) {
REQUIRE_TRUE(xShapeInfo[0] == 2, 0, "GRU_CELL_BP: rank of input array x must be 2, but got %i instead", xShapeInfo[0]);
const std::string hiShape = ShapeUtils::shapeAsString(hiShapeInfo);
|
||||
const std::string hiCorrectShape = ShapeUtils::shapeAsString({bS, nU});
|
||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({iS+nU, 2*nU});
|
||||
const std::string wcShape = ShapeUtils::shapeAsString(wcShapeInfo);
|
||||
const std::string wcCorrectShape = ShapeUtils::shapeAsString({iS+nU, nU});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*nU});
|
||||
const std::string bcShape = ShapeUtils::shapeAsString(bcShapeInfo);
|
||||
const std::string bcCorrectShape = ShapeUtils::shapeAsString({nU});
|
||||
const std::string dLdrShape = ShapeUtils::shapeAsString(dLdrShapeInfo);
|
||||
const std::string dLduShape = ShapeUtils::shapeAsString(dLduShapeInfo);
|
||||
const std::string dLdcShape = ShapeUtils::shapeAsString(dLdcShapeInfo);
|
||||
const std::string dLdhShape = ShapeUtils::shapeAsString(dLdhShapeInfo);
|
||||
const std::vector<Nd4jLong> hiCorrectShape = {bS, nU};
|
||||
const std::vector<Nd4jLong> wCorrectShape = {iS+nU, 2*nU};
|
||||
const std::vector<Nd4jLong> wcCorrectShape = {iS+nU, nU};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {2*nU};
|
||||
const std::vector<Nd4jLong> bcCorrectShape = {nU};
REQUIRE_TRUE(hiShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", hiCorrectShape.c_str(), hiShape.c_str());
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(wcShape == wcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", wcCorrectShape.c_str(), wcShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(bcShape == bcCorrectShape, 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", bcCorrectShape.c_str(), bcShape.c_str());
|
||||
REQUIRE_TRUE(dLdrShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdrShape.c_str());
|
||||
REQUIRE_TRUE(dLduShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLduShape.c_str());
|
||||
REQUIRE_TRUE(dLdcShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdcShape.c_str());
|
||||
REQUIRE_TRUE(dLdhShape == hiCorrectShape, 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", hiCorrectShape.c_str(), dLdhShape.c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(hiShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(hiShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "GRU_CELL_BP op: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wcShapeInfo, wcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wcCorrectShape).c_str(), ShapeUtils::shapeAsString(wcShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU_CELL_BP op: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bcShapeInfo, bcCorrectShape), 0, "GRU_CELL_BP op: wrong shape of c biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bcCorrectShape).c_str(), ShapeUtils::shapeAsString(bcShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdrShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdr array (gradient wrt reset gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdrShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLduShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdu array (gradient wrt update gate), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLduShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdcShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdcShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(dLdhShapeInfo, hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdhShapeInfo).c_str());
Nd4jLong *dLdxShapeInfo = nullptr;
|
||||
COPY_SHAPE(xShapeInfo, dLdxShapeInfo);
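The gruCell_bp hunks above repeat the same {bS, nU} check for dLdr, dLdu, dLdc and dLdh. A hypothetical standalone sketch of folding such repeated checks into one loop (illustration only, not the op's actual code):

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// validate that every gradient buffer matches the previous-output shape {bS, nU};
// a sketch of the repeated dLdr/dLdu/dLdc/dLdh checks above
static bool gradShapesOk(const std::vector<std::pair<std::string, Shape>>& grads,
                         const Shape& expected, std::string& err) {
    for (const auto& g : grads)
        if (g.second != expected) {
            err = "wrong shape of " + g.first + " array";
            return false;
        }
    return true;
}

int main() {
    const int64_t bS = 2, nU = 3;
    Shape expected = {bS, nU};
    std::string err;
    bool ok = gradShapesOk({{"dLdr", {2, 3}}, {"dLdu", {2, 3}}, {"dLdc", {2, 4}}}, expected, err);
    std::cout << (ok ? "ok" : err) << '\n';               // prints: wrong shape of dLdc array
    return 0;
}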
@ -39,10 +39,10 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
|
|||
auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits]
|
||||
auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj]
|
||||
auto b = INPUT_VARIABLE(7); // biases, [4*numUnits]
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x numProj], that is per each time step
|
||||
auto c = OUTPUT_VARIABLE(1); // cell states [time x bS x numUnits] that is per each time step
const int peephole = INT_ARG(0); // if 1, provide peephole connections
|
||||
const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!!
@ -59,28 +59,21 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
|
|||
const int numUnits = c0->sizeAt(1);
// input shapes validation
|
||||
const std::string h0Shape = ShapeUtils::shapeAsString(h0);
|
||||
const std::string correctH0Shape = ShapeUtils::shapeAsString({bS, numProj});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||
const std::string correctC0Shape = ShapeUtils::shapeAsString({bS, numUnits});
|
||||
const std::string WxShape = ShapeUtils::shapeAsString(Wx);
|
||||
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
|
||||
const std::string WhShape = ShapeUtils::shapeAsString(Wh);
|
||||
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
|
||||
const std::string WcShape = ShapeUtils::shapeAsString(Wc);
|
||||
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
|
||||
const std::string WpShape = ShapeUtils::shapeAsString(Wp);
|
||||
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
|
||||
const std::vector<Nd4jLong> correctH0Shape = {bS, numProj};
|
||||
const std::vector<Nd4jLong> correctC0Shape = {bS, numUnits};
|
||||
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
|
||||
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
|
||||
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(correctH0Shape == h0Shape, 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctH0Shape.c_str(), h0Shape.c_str());
|
||||
REQUIRE_TRUE(correctC0Shape == c0Shape, 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctC0Shape.c_str(), c0Shape.c_str());
|
||||
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
|
||||
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
|
||||
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
|
||||
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(h0->isSameShape(correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
REQUIRE_TRUE(c0->isSameShape(correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
|
||||
REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||
REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
|
||||
REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTM operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !");
helpers::lstmTimeLoop(block.launchContext(), x, h0, c0, Wx, Wh, Wc, Wp, b, h, c, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias});
@ -95,7 +88,7 @@ CUSTOM_OP_IMPL(lstm, 8, 2, false, 3, 2) {
|
|||
}
DECLARE_SHAPE_FN(lstm) {
auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
|
||||
auto h0ShapeInfo = inputShape->at(1); // initial cell output (at time step = 0) [bS x numProj], in case of projection=false -> numProj == numUnits !!!
@ -113,37 +106,30 @@ DECLARE_SHAPE_FN(lstm) {
|
|||
const int inSize = xShapeInfo[3];
|
||||
const int numProj = h0ShapeInfo[2];
|
||||
const int numUnits = c0ShapeInfo[2];
// input shapes validation
|
||||
const std::string h0Shape = ShapeUtils::shapeAsString(h0ShapeInfo);
|
||||
const std::string correctH0Shape = ShapeUtils::shapeAsString({bS, numProj});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||
const std::string correctC0Shape = ShapeUtils::shapeAsString({bS, numUnits});
|
||||
const std::string WxShape = ShapeUtils::shapeAsString(WxShapeInfo);
|
||||
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
|
||||
const std::string WhShape = ShapeUtils::shapeAsString(WhShapeInfo);
|
||||
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
|
||||
const std::string WcShape = ShapeUtils::shapeAsString(WcShapeInfo);
|
||||
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
|
||||
const std::string WpShape = ShapeUtils::shapeAsString(WpShapeInfo);
|
||||
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
|
||||
const std::vector<Nd4jLong> correctH0Shape = {bS, numProj};
|
||||
const std::vector<Nd4jLong> correctC0Shape = {bS, numUnits};
|
||||
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
|
||||
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
|
||||
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, correctH0Shape), 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctH0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, correctC0Shape), 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctC0Shape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
REQUIRE_TRUE(correctH0Shape == h0Shape, 0, "LSTM operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctH0Shape.c_str(), h0Shape.c_str());
|
||||
REQUIRE_TRUE(correctC0Shape == c0Shape, 0, "LSTM operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctC0Shape.c_str(), c0Shape.c_str());
|
||||
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTM operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
|
||||
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTM operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
|
||||
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTM operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
|
||||
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTM operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTM operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
|
||||
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numProj]
|
||||
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [time x bS x numUnits]
hShapeInfo[0] = cShapeInfo[0] = rank;
|
||||
hShapeInfo[1] = cShapeInfo[1] = time;
|
||||
hShapeInfo[2] = cShapeInfo[2] = bS;
@ -152,9 +138,9 @@ DECLARE_SHAPE_FN(lstm) {
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo));
|
||||
ShapeUtils::updateStridesAndType(cShapeInfo, xShapeInfo, shape::order(c0ShapeInfo));
return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(cShapeInfo));
|
||||
}
|
||||
}
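The lstm op above validates seven parameter arrays and additionally requires numProj == numUnits when projection is disabled. A compact standalone sketch of those expectations (invented helper names; assumptions are flagged in the comments):

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// expected LSTM parameter shapes as validated above (illustrative helper only)
struct LstmExpected {
    Shape h0, c0, Wx, Wh, Wc, Wp, b;
};
static LstmExpected lstmExpected(int64_t bS, int64_t inSize, int64_t numProj, int64_t numUnits) {
    return { {bS, numProj}, {bS, numUnits},
             {inSize, 4 * numUnits}, {numProj, 4 * numUnits},
             {3 * numUnits}, {numUnits, numProj}, {4 * numUnits} };
}

int main() {
    const int64_t bS = 2, inSize = 5, numUnits = 3;
    const bool projection = false;
    const int64_t numProj = numUnits;                     // with projection off, numProj must equal numUnits
    assert(projection || numProj == numUnits);
    auto e = lstmExpected(bS, inSize, numProj, numUnits);
    assert((e.Wx == Shape{5, 12}));
    assert((e.Wh == Shape{3, 12}));
    return 0;
}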
@ -39,10 +39,10 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
|
|||
auto Wc = INPUT_VARIABLE(5); // diagonal weights for peephole connections [3*numUnits]
|
||||
auto Wp = INPUT_VARIABLE(6); // projection weights [numUnits x numProj]
|
||||
auto b = INPUT_VARIABLE(7); // biases, [4*numUnits]
auto ht = OUTPUT_VARIABLE(0); // current cell output [bS x numProj], that is at current time step t
|
||||
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x numUnits], that is at current time step t
const int peephole = INT_ARG(0); // if 1, provide peephole connections
|
||||
const int projection = INT_ARG(1); // if 1, then projection is performed, if false then numProj==numUnits is mandatory!!!!
@ -51,40 +51,33 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
|
|||
const double clippingProjValue = T_ARG(1); // clipping value for projected ht, if it is not equal to zero, then projected cell output is clipped
|
||||
const double forgetBias = T_ARG(2);
const int rank = xt->rankOf();
|
||||
const int bS = xt->sizeAt(0);
|
||||
const int inSize = xt->sizeAt(1);
|
||||
const int numProj = ht_1->sizeAt(1);
|
||||
const int numUnits = ct_1->sizeAt(1);
// input shapes validation
|
||||
const std::string ht_1Shape = ShapeUtils::shapeAsString(ht_1);
|
||||
const std::string correctHt_1Shape = ShapeUtils::shapeAsString({bS, numProj});
|
||||
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1);
|
||||
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, numUnits});
|
||||
const std::string WxShape = ShapeUtils::shapeAsString(Wx);
|
||||
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
|
||||
const std::string WhShape = ShapeUtils::shapeAsString(Wh);
|
||||
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
|
||||
const std::string WcShape = ShapeUtils::shapeAsString(Wc);
|
||||
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
|
||||
const std::string WpShape = ShapeUtils::shapeAsString(Wp);
|
||||
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
|
||||
const int numUnits = ct_1->sizeAt(1);
REQUIRE_TRUE(correctHt_1Shape == ht_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctHt_1Shape.c_str(), ht_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
|
||||
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
|
||||
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
|
||||
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
|
||||
// input shapes validation
|
||||
const std::vector<Nd4jLong> correctHt_1Shape = {bS, numProj};
|
||||
const std::vector<Nd4jLong> correctCt_1Shape = {bS, numUnits};
|
||||
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
|
||||
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
|
||||
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ht_1->isSameShape(correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1).c_str());
|
||||
REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str());
|
||||
REQUIRE_TRUE(Wx->isSameShape(correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
|
||||
REQUIRE_TRUE(Wh->isSameShape(correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(Wc->isSameShape(correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(Wc).c_str());
|
||||
REQUIRE_TRUE(Wp->isSameShape(correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(!(!projection && numUnits != numProj), 0, "LSTMCELL operation: projection option is switched of, and in this case output dimensionality for the projection matrices (numProj) must be equal to number of units in lstmCell !");
// calculations
|
||||
helpers::lstmCell(block.launchContext(), xt,ht_1,ct_1, Wx,Wh,Wc,Wp, b, ht,ct, {(double)peephole, (double)projection, clippingCellValue, clippingProjValue, forgetBias});
return Status::OK();
|
||||
}
@ -95,53 +88,46 @@ CUSTOM_OP_IMPL(lstmCell, 8, 2, false, 3, 2) {
|
|||
}
DECLARE_SHAPE_FN(lstmCell) {
auto xtShapeInfo   = inputShape->at(0); // input [bS x inSize]
auto ht_1ShapeInfo = inputShape->at(1); // previous cell output [bS x numProj], that is at previous time step t-1, in case of projection=false -> numProj=numUnits!!!
auto ct_1ShapeInfo = inputShape->at(2); // previous cell state [bS x numUnits], that is at previous time step t-1
auto WxShapeInfo = inputShape->at(3); // input-to-hidden weights, [inSize x 4*numUnits]
auto WhShapeInfo = inputShape->at(4); // hidden-to-hidden weights, [numProj x 4*numUnits]
auto WcShapeInfo = inputShape->at(5); // diagonal weights for peephole connections [3*numUnits]
auto WpShapeInfo = inputShape->at(6); // projection weights [numUnits x numProj]
auto bShapeInfo  = inputShape->at(7); // biases, [4*numUnits]
const int rank = shape::rank(xtShapeInfo);
|
||||
const auto bS = xtShapeInfo[1];
|
||||
const auto inSize = xtShapeInfo[2];
|
||||
const auto numProj = ht_1ShapeInfo[2];
|
||||
const auto numUnits = ct_1ShapeInfo[2];
// input shapes validation
|
||||
const std::string ht_1Shape = ShapeUtils::shapeAsString(ht_1ShapeInfo);
|
||||
const std::string correctHt_1Shape = ShapeUtils::shapeAsString({bS, numProj});
|
||||
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1ShapeInfo);
|
||||
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, numUnits});
|
||||
const std::string WxShape = ShapeUtils::shapeAsString(WxShapeInfo);
|
||||
const std::string correctWxShape = ShapeUtils::shapeAsString({inSize, 4*numUnits});
|
||||
const std::string WhShape = ShapeUtils::shapeAsString(WhShapeInfo);
|
||||
const std::string correctWhShape = ShapeUtils::shapeAsString({numProj, 4*numUnits});
|
||||
const std::string WcShape = ShapeUtils::shapeAsString(WcShapeInfo );
|
||||
const std::string correctWcShape = ShapeUtils::shapeAsString({3*numUnits});
|
||||
const std::string WpShape = ShapeUtils::shapeAsString(WpShapeInfo);
|
||||
const std::string correctWpShape = ShapeUtils::shapeAsString({numUnits, numProj});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({4*numUnits});
REQUIRE_TRUE(correctHt_1Shape == ht_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", correctHt_1Shape.c_str(), ht_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctWxShape == WxShape, 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", correctWxShape.c_str(), WxShape.c_str());
|
||||
REQUIRE_TRUE(correctWhShape == WhShape, 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", correctWhShape.c_str(), WhShape.c_str());
|
||||
REQUIRE_TRUE(correctWcShape == WcShape, 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", correctWcShape.c_str(), WcShape.c_str());
|
||||
REQUIRE_TRUE(correctWpShape == WpShape, 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", correctWpShape.c_str(), WpShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
// input shapes validation
|
||||
const std::vector<Nd4jLong> correctHt_1Shape = {bS, numProj};
|
||||
const std::vector<Nd4jLong> correctCt_1Shape = {bS, numUnits};
|
||||
const std::vector<Nd4jLong> correctWxShape = {inSize, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWhShape = {numProj, 4*numUnits};
|
||||
const std::vector<Nd4jLong> correctWcShape = {3*numUnits};
|
||||
const std::vector<Nd4jLong> correctWpShape = {numUnits, numProj};
|
||||
const std::vector<Nd4jLong> correctBShape = {4*numUnits};
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ht_1ShapeInfo, correctHt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctHt_1Shape).c_str(), ShapeUtils::shapeAsString(ht_1ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape), 0, "LSTMCELL operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, correctWxShape), 0, "LSTMCELL operation: wrong shape of input-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, correctWhShape), 0, "LSTMCELL operation: wrong shape of hidden-to-hidden weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WcShapeInfo, correctWcShape), 0, "LSTMCELL operation: wrong shape of diagonal weights for peephole connections, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWcShape).c_str(), ShapeUtils::shapeAsString(WcShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WpShapeInfo, correctWpShape), 0, "LSTMCELL operation: wrong shape of projection weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWpShape).c_str(), ShapeUtils::shapeAsString(WpShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, correctBShape), 0, "LSTMCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
|
||||
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj]
|
||||
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits]
hShapeInfo[0] = cShapeInfo[0] = rank;
|
||||
hShapeInfo[1] = cShapeInfo[1] = bS;
|
||||
hShapeInfo[2] = numProj;
@ -154,7 +140,7 @@ DECLARE_SHAPE_FN(lstmCell) {
|
|||
RELEASE(hShapeInfo, block.workspace());
|
||||
RELEASE(cShapeInfo, block.workspace());
|
||||
return result;
|
||||
}
|
||||
}
}
|
||||
}
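The sru ops below treat the mask input as optional and only shape-check it when it is supplied. A tiny standalone sketch of that optional-input pattern (illustrative names; the real op works on NDArray*, not raw shape vectors):

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// only validate the optional mask when it is present, as the sru op does below
static bool maskOk(const Shape* mask, const Shape& c0Expected) {
    return mask == nullptr || *mask == c0Expected;
}

int main() {
    const int64_t bS = 2, inSize = 4;
    Shape c0Expected = {bS, inSize};
    Shape badMask    = {bS, inSize + 1};
    std::cout << maskOk(nullptr, c0Expected) << ' '       // 1: absent mask passes
              << maskOk(&badMask, c0Expected) << '\n';    // 0: wrong shape rejected
    return 0;
}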
@ -54,20 +54,15 @@ CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) {
|
|||
if(mask)
|
||||
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {3*inSize, inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {2*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
if(mask) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
|
||||
if(mask)
|
||||
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
// xm = x * mask
|
||||
auto xm = x;
@ -111,20 +106,15 @@ DECLARE_SHAPE_FN(sru) {
|
|||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({3*inSize, inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({2*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {3*inSize, inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {2*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
if(maskShapeInfo) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
|
||||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
Nd4jLong* newShapeInfo1 = nullptr;
|
||||
ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x inSize x time]
@ -350,20 +340,15 @@ CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {
|
|||
if(mask)
|
||||
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
const std::string wShape = ShapeUtils::shapeAsString(w);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
if(mask) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
|
||||
if(mask)
|
||||
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
|
||||
|
||||
helpers::sruBI(block.launchContext(), x, w, b, c0, mask, ht, ct);
|
||||
|
||||
|
@ -397,20 +382,16 @@ DECLARE_SHAPE_FN(sru_bi) {
|
|||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
|
||||
|
||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
|
||||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
|
||||
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
if(maskShapeInfo) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
|
||||
char order = shape::order(xShapeInfo);
|
||||
|
||||
|
@ -453,23 +434,17 @@ CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {
|
|||
if(mask)
|
||||
REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
|
||||
|
||||
const std::string wShape = ShapeUtils::shapeAsString(w);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||
const std::string ctShape = ShapeUtils::shapeAsString(ct);
|
||||
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
|
||||
const std::vector<Nd4jLong> ctCorrectShape = {time, bS, 2*inSize};
|
||||
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
|
||||
if(mask) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(mask);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(w->isSameShape(wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
|
||||
REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ct).c_str());
|
||||
if(mask)
|
||||
REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());
|
||||
|
||||
auto gradI = OUTPUT_VARIABLE(0); // [time x bS x 2*inSize]
|
||||
auto gradW = OUTPUT_VARIABLE(1); // [time x 2*inSize x 6*inSize]
|
||||
|
@ -507,29 +482,21 @@ DECLARE_SHAPE_FN(sru_bi_bp) {
|
|||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);
|
||||
|
||||
const std::string wShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string wCorrectShape = ShapeUtils::shapeAsString({2*inSize, 6*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string bCorrectShape = ShapeUtils::shapeAsString({4*inSize});
|
||||
const std::string c0Shape = ShapeUtils::shapeAsString(c0ShapeInfo);
|
||||
const std::string c0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||
const std::string ctShape = ShapeUtils::shapeAsString(ctShapeInfo);
|
||||
const std::string ctCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
|
||||
const std::string inGradC0Shape = ShapeUtils::shapeAsString(inGradC0ShapeInfo);
|
||||
const std::string inGradC0CorrectShape = ShapeUtils::shapeAsString({bS, 2*inSize});
|
||||
const std::string inGradHtShape = ShapeUtils::shapeAsString(inGradHtShapeInfo);
|
||||
const std::string inGradHtCorrectShape = ShapeUtils::shapeAsString({time, bS, 2*inSize});
|
||||
const std::vector<Nd4jLong> wCorrectShape = {2*inSize, 6*inSize};
|
||||
const std::vector<Nd4jLong> bCorrectShape = {4*inSize};
|
||||
const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
|
||||
const std::vector<Nd4jLong> ctCorrectShape = {time, bS, 2*inSize};
|
||||
const std::vector<Nd4jLong> inGradC0CorrectShape = {bS, 2*inSize};
|
||||
const std::vector<Nd4jLong> inGradHtCorrectShape = {time, bS, 2*inSize};
|
||||
|
||||
REQUIRE_TRUE(wShape == wCorrectShape, 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
|
||||
REQUIRE_TRUE(bShape == bCorrectShape, 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
|
||||
REQUIRE_TRUE(ctShape == ctCorrectShape, 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ctCorrectShape.c_str(), ctShape.c_str());
|
||||
REQUIRE_TRUE(inGradC0Shape == inGradC0CorrectShape, 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", inGradC0CorrectShape.c_str(), inGradC0Shape.c_str());
|
||||
REQUIRE_TRUE(inGradHtShape == inGradHtCorrectShape, 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", inGradHtCorrectShape.c_str(), inGradHtShape.c_str());
|
||||
if(maskShapeInfo) {
|
||||
const std::string maskShape = ShapeUtils::shapeAsString(maskShapeInfo);
|
||||
REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
|
||||
}
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, "SRU_BI operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ctShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, "SRU_BI operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, "SRU_BI operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str());
|
||||
if(maskShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());
|
||||
|
||||
const char order = shape::order(xShapeInfo);
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace ops {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
|
||||
auto xt = INPUT_VARIABLE(0); // input [bS x inSize], bS - batch size, inSize - number of features
|
||||
auto ct_1 = INPUT_VARIABLE(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
|
||||
auto w = INPUT_VARIABLE(2); // weights [inSize x 3*inSize]
|
||||
auto b = INPUT_VARIABLE(3); // biases [2*inSize]
|
||||
|
@ -40,25 +40,22 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
|
|||
auto ct = OUTPUT_VARIABLE(1); // current cell state [bS x inSize], that is at current time step t
|
||||
|
||||
const int rank = xt->rankOf();
|
||||
const int bS = xt->sizeAt(0);
|
||||
const int inSize = xt->sizeAt(1); // inSize - number of features
|
||||
|
||||
// input shapes validation
|
||||
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1);
|
||||
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, inSize});
|
||||
const std::string WShape = ShapeUtils::shapeAsString(w);
|
||||
const std::string correctWShape = ShapeUtils::shapeAsString({inSize, 3*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(b);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({2*inSize});
|
||||
const std::vector<Nd4jLong> correctCt_1Shape = {bS, inSize};
|
||||
const std::vector<Nd4jLong> correctWShape = {inSize, 3*inSize};
|
||||
const std::vector<Nd4jLong> correctBShape = {2*inSize};
|
||||
|
||||
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctWShape == WShape, 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", correctWShape.c_str(), WShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
|
||||
REQUIRE_TRUE(ct_1->isSameShape(correctCt_1Shape), 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1).c_str());
|
||||
REQUIRE_TRUE(w->isSameShape(correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
|
||||
|
||||
// fixme: clean up these initializer lists
|
||||
helpers::sruCell(block.launchContext(), xt, ct_1, w, b, ht, ct);
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -71,40 +68,37 @@ CUSTOM_OP_IMPL(sruCell, 4, 2, false, 0, 0) {
|
|||
DECLARE_SHAPE_FN(sruCell) {
|
||||
|
||||
auto xtShapeInfo = inputShape->at(0); // input [bS x inSize], bS - batch size, inSize - number of features
|
||||
auto ct_1ShapeInfo = inputShape->at(1); // previous cell state ct [bS x inSize], that is at previous time step t-1
|
||||
auto wShapeInfo = inputShape->at(2); // weights [inSize x 3*inSize]
|
||||
auto bShapeInfo = inputShape->at(3); // biases [2*inSize]
|
||||
|
||||
const int rank = xtShapeInfo[0];
|
||||
const int bS = xtShapeInfo[1];
|
||||
const int inSize = xtShapeInfo[2]; // inSize - number of features
|
||||
|
||||
// input shapes validation
|
||||
const std::string ct_1Shape = ShapeUtils::shapeAsString(ct_1ShapeInfo);
|
||||
const std::string correctCt_1Shape = ShapeUtils::shapeAsString({bS, inSize});
|
||||
const std::string WShape = ShapeUtils::shapeAsString(wShapeInfo);
|
||||
const std::string correctWShape = ShapeUtils::shapeAsString({inSize, 3*inSize});
|
||||
const std::string bShape = ShapeUtils::shapeAsString(bShapeInfo);
|
||||
const std::string correctBShape = ShapeUtils::shapeAsString({2*inSize});
|
||||
const std::vector<Nd4jLong> correctCt_1Shape = {bS, inSize};
|
||||
const std::vector<Nd4jLong> correctWShape = {inSize, 3*inSize};
|
||||
const std::vector<Nd4jLong> correctBShape = {2*inSize};
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(ct_1ShapeInfo, correctCt_1Shape) , 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctCt_1Shape).c_str(), ShapeUtils::shapeAsString(ct_1ShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo ,correctWShape), 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctWShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo ,correctBShape), 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(correctBShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
|
||||
REQUIRE_TRUE(correctCt_1Shape == ct_1Shape, 0, "SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead !", correctCt_1Shape.c_str(), ct_1Shape.c_str());
|
||||
REQUIRE_TRUE(correctWShape == WShape, 0, "SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead !", correctWShape.c_str(), WShape.c_str());
|
||||
REQUIRE_TRUE(correctBShape == bShape, 0, "SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead !", correctBShape.c_str(), bShape.c_str());
|
||||
|
||||
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *cShapeInfo(nullptr);
|
||||
ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numProj]
|
||||
ALLOCATE(cShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); // [bS x numUnits]
|
||||
|
||||
|
||||
hShapeInfo[0] = cShapeInfo[0] = rank;
|
||||
hShapeInfo[1] = cShapeInfo[1] = bS;
|
||||
hShapeInfo[2] = cShapeInfo[2] = inSize;
|
||||
|
||||
|
||||
ShapeUtils::updateStridesAndType(hShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo));
|
||||
ShapeUtils::updateStridesAndType(cShapeInfo, ct_1ShapeInfo, shape::order(ct_1ShapeInfo));
|
||||
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createFromExisting(hShapeInfo, block.workspace()), ConstantShapeHelper::getInstance()->createFromExisting(cShapeInfo, block.workspace()));
|
||||
}
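// For concreteness, a sketch of the shape buffers built above, assuming a 'c'-ordered ct_1 with
// bS = 4 and inSize = 3 (the trailing ews/order/type fields are filled in by updateStridesAndType
// and are only indicated here):
//
//     // rank, dim0, dim1, stride0, stride1, ...
//     // hShapeInfo == cShapeInfo == { 2,  4, 3,  3, 1,  ... }   -> both outputs are [bS x inSize] = [4 x 3]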
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -55,14 +55,14 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
|
|||
maxTimeStep = INPUT_VARIABLE(9);
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x (numUnitsFW + numUnitsBW)], that is per each time step
|
||||
auto hFWFinal = OUTPUT_VARIABLE(1); // final cell out for forward RNN [bS x numUnitsFW]
|
||||
auto hBWFinal = OUTPUT_VARIABLE(2); // final cell out for backward RNN [bS x numUnitsBW]
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input array must have rank = 3, but got %i instead !", x->rankOf());
|
||||
REQUIRE_TRUE(WxFW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for forward RNN) must have rank = 2, but got %i instead !", WxFW->rankOf());
|
||||
REQUIRE_TRUE(WxBW->rankOf() == 2, 0, "STATIC_BIDIRECTIONAL_RNN custom operation: input-to-hidden weights array (for backward RNN) must have rank = 2, but got %i instead !", WxBW->rankOf());
|
||||
|
||||
const Nd4jLong inRank = x->rankOf();
|
||||
const Nd4jLong time = x->sizeAt(0);
|
||||
|
@ -70,39 +70,48 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
|
|||
const Nd4jLong numUnitsFW = WxFW->sizeAt(1);
|
||||
const Nd4jLong numUnitsBW = WxBW->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFW) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBW) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFW) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBW) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FW) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
if(h0BW)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BW) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
const std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
|
||||
const std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
|
||||
const std::vector<Nd4jLong> expectedbFWshape = {2 * numUnitsFW};
|
||||
const std::vector<Nd4jLong> expectedbBWshape = {2 * numUnitsBW};
|
||||
|
||||
// forward steps
|
||||
REQUIRE_TRUE(WhFW->isSameShape(expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFW).c_str());
|
||||
REQUIRE_TRUE(WhBW->isSameShape(expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBW).c_str());
|
||||
REQUIRE_TRUE(bFW->isSameShape(expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFW).c_str());
|
||||
REQUIRE_TRUE(bBW->isSameShape(expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBW).c_str());
|
||||
if(h0FW) {
|
||||
const std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
|
||||
REQUIRE_TRUE(h0FW->isSameShape(expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FW).c_str());
|
||||
}
|
||||
if(h0BW) {
|
||||
const std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
|
||||
REQUIRE_TRUE(h0BW->isSameShape(expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BW).c_str());
|
||||
}
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
|
||||
// forward steps
|
||||
auto hFW = new NDArray(x->ordering(), {time, bS, numUnitsFW}, x->dataType(), block.launchContext());
|
||||
helpers::rnnTimeLoop(block.launchContext(), x, WxFW, WhFW, bFW, h0FW, maxTimeStep, hFW, hFWFinal);
|
||||
|
||||
auto seqLen = maxTimeStep;
|
||||
if(seqLen == nullptr) {
|
||||
// seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, x->dataType(), block.launchContext()); // [bS]
|
||||
seqLen = new NDArray(x->ordering(), {x->sizeAt(1)}, sd::DataType::INT64, block.launchContext()); // [bS]
|
||||
*seqLen = x->sizeAt(0); // set each element of seqLen to be equal to time
|
||||
}
|
||||
|
||||
// reverse x
|
||||
auto revOut = new NDArray(x, false, block.launchContext());
|
||||
helpers::reverseSequence(block.launchContext(), x, seqLen, revOut, 0, 1);
|
||||
|
||||
// backward steps
|
||||
auto hBW = new NDArray(x->ordering(), {time, bS, numUnitsBW}, x->dataType(), block.launchContext());
|
||||
|
||||
|
||||
helpers::rnnTimeLoop(block.launchContext(), revOut, WxBW, WhBW, bBW, h0BW, maxTimeStep, hBW, hBWFinal);
|
||||
|
||||
// reverse hBW
|
||||
auto hBWcopy = new NDArray(*hBW);
|
||||
helpers::reverseSequence(block.launchContext(), hBWcopy, seqLen, hBW, 0, 1);
|
||||
|
||||
// concatenate hFW and hBW along last third dimension
|
||||
|
@ -117,7 +126,7 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
|
|||
if(seqLen != maxTimeStep)
|
||||
delete seqLen;
|
||||
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
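// Rough data flow of the op above, shapes only (a sketch of what the helpers are asked to do,
// not of their implementations):
//
//     hFW  = rnnTimeLoop(x,    WxFW, WhFW, bFW, h0FW, seqLen)    // [time, bS, numUnitsFW]
//     xRev = reverseSequence(x, seqLen, dims 0/1)                 // time axis reversed per sample
//     hBW  = rnnTimeLoop(xRev, WxBW, WhBW, bBW, h0BW, seqLen)    // [time, bS, numUnitsBW]
//     hBW  = reverseSequence(hBW, seqLen, dims 0/1)               // re-aligned with forward time order
//     h    = concat(hFW, hBW, axis = 2)                           // [time, bS, numUnitsFW + numUnitsBW]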
|
||||
|
||||
|
@ -128,7 +137,7 @@ CUSTOM_OP_IMPL(static_bidirectional_rnn, 7, 3, false, 0, 0) {
|
|||
}
|
||||
|
||||
|
||||
DECLARE_SHAPE_FN(static_bidirectional_rnn) {
|
||||
|
||||
auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize]
|
||||
auto WxFWShapeInfo = inputShape->at(1); // input-to-hidden weights for forward RNN, [inSize x numUnitsFW]
|
||||
|
@ -167,16 +176,25 @@ DECLARE_SHAPE_FN(static_bidirectional_rnn) {
|
|||
const int numUnitsFW = WxFWShapeInfo[2];
|
||||
const int numUnitsBW = WxBWShapeInfo[2];
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhFWShapeInfo) == ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsFW, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhBWShapeInfo) == ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({numUnitsBW, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bFWShapeInfo) == ShapeUtils::shapeAsString({2*numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsFW}).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bBWShapeInfo) == ShapeUtils::shapeAsString({2*numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnitsBW}).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str());
|
||||
if(h0FWShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0FWShapeInfo) == ShapeUtils::shapeAsString({bS, numUnitsFW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsFW}).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str());
|
||||
if(h0BWShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0BWShapeInfo) == ShapeUtils::shapeAsString({bS, numUnitsBW}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnitsBW}).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str());
|
||||
const std::vector<Nd4jLong> expectedWhFWshape = {numUnitsFW, numUnitsFW};
|
||||
const std::vector<Nd4jLong> expectedWhBWshape = {numUnitsBW, numUnitsBW};
|
||||
const std::vector<Nd4jLong> expectedbFWshape = {2 * numUnitsFW};
|
||||
const std::vector<Nd4jLong> expectedbBWshape = {2 * numUnitsBW};
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhFWShapeInfo, expectedWhFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhFWshape).c_str(), ShapeUtils::shapeAsString(WhFWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhBWShapeInfo, expectedWhBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of hidden-to-hidden weights array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedWhBWshape).c_str(), ShapeUtils::shapeAsString(WhBWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bFWShapeInfo, expectedbFWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for forward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbFWshape).c_str(), ShapeUtils::shapeAsString(bFWShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bBWShapeInfo, expectedbBWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of biases array (for backward RNN), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbBWshape).c_str(), ShapeUtils::shapeAsString(bBWShapeInfo).c_str());
|
||||
if(h0FWShapeInfo) {
|
||||
const std::vector<Nd4jLong> expectedh0FWshape = {bS, numUnitsFW};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0FWShapeInfo, expectedh0FWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for forward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0FWshape).c_str(), ShapeUtils::shapeAsString(h0FWShapeInfo).c_str());
|
||||
}
|
||||
if(h0BWShapeInfo) {
|
||||
const std::vector<Nd4jLong> expectedh0BWshape = {bS, numUnitsBW};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0BWShapeInfo, expectedh0BWshape), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of initial cell output array (for backward RNN), expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0BWshape).c_str(), ShapeUtils::shapeAsString(h0BWShapeInfo).c_str());
|
||||
}
|
||||
if(maxTimeStepShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_BIDIRECTIONAL_RNN custom operation: wrong shape of maxTimeStep array, expected is [%i], but got %s instead !", bS, ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
|
||||
|
||||
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *hFWFinalPrevShapeInfo(nullptr), *hBWFinalPrevShapeInfo(nullptr);
|
||||
|
@ -195,9 +213,9 @@ DECLARE_SHAPE_FN(static_bidirectional_rnn) {
|
|||
ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo));
|
||||
ShapeUtils::updateStridesAndType(hFWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo));
|
||||
ShapeUtils::updateStridesAndType(hBWFinalPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo));
|
||||
|
||||
|
||||
return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hFWFinalPrevShapeInfo), CONSTANT(hBWFinalPrevShapeInfo));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -58,12 +58,17 @@ CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) {
|
|||
const int inSize = x->sizeAt(2);
|
||||
const int numUnits = Wx->sizeAt(1);
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(Wh) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(b) == ShapeUtils::shapeAsString({2*numUnits}), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
if(h0)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
|
||||
const std::vector<Nd4jLong> expectedbShape = {2 * numUnits};
|
||||
|
||||
REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
|
||||
REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
|
||||
if(h0) {
|
||||
const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
|
||||
REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
|
||||
}
|
||||
if(maxTimeStep)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStep) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
REQUIRE_TRUE(maxTimeStep->isSameShape({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
|
||||
|
||||
|
||||
helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal);
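// rnnTimeLoop itself is not part of this diff; as a reminder, a basic RNN cell of this kind is
// usually assumed to compute the following (the actual helper may differ, e.g. in activation or
// bias layout):
//
//     // x_t: [bS, inSize], Wx: [inSize, numUnits], Wh: [numUnits, numUnits], b: [2*numUnits]
//     // h_t = activation(x_t * Wx + h_{t-1} * Wh + b),  giving h_t: [bS, numUnits]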
|
||||
|
@ -107,12 +112,17 @@ DECLARE_SHAPE_FN(static_rnn) {
|
|||
const int bS = xShapeInfo[2];
|
||||
const int numUnits = WxShapeInfo[2];
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(WhShapeInfo) == ShapeUtils::shapeAsString({numUnits, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({numUnits, numUnits}).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(bShapeInfo) == ShapeUtils::shapeAsString({2*numUnits}), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2*numUnits}).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
if(h0ShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(h0ShapeInfo) == ShapeUtils::shapeAsString({bS, numUnits}), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString({bS, numUnits}).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
|
||||
const std::vector<Nd4jLong> expectedbShape = {2 * numUnits};
|
||||
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
|
||||
if(h0ShapeInfo){
|
||||
const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
|
||||
}
|
||||
if(maxTimeStepShapeInfo)
|
||||
REQUIRE_TRUE(ShapeUtils::shapeAsString(maxTimeStepShapeInfo) == ShapeUtils::shapeAsString({bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
|
||||
REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, {bS}), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS}).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
|
||||
|
||||
// evaluate output shapeInfos
|
||||
Nd4jLong *hShapeInfo(nullptr), *hPrevShapeInfo(nullptr);
|
||||
|
|
|
@ -38,7 +38,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
|
|||
|
||||
// first of all take into account possible presence of empty arrays
|
||||
// also if scalar is present -> copy its value to vector with length=1
|
||||
std::vector<NDArray*> nonEmptyArrs;
|
||||
std::vector<const NDArray*> nonEmptyArrs;
|
||||
std::vector<int> arrsToDelete;
|
||||
int index = 0;
|
||||
bool allOfSameType = true;
|
||||
|
|
|
@ -36,16 +36,16 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
|
|||
auto paddings = INPUT_VARIABLE(1);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
const int rank = input->rankOf();
|
||||
|
||||
// input validation
|
||||
std::string expectedPaddingsShape = ShapeUtils::shapeAsString({rank, 2});
|
||||
std::string currentPaddingsShape = ShapeUtils::shapeAsString(paddings);
|
||||
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", expectedPaddingsShape.c_str(), currentPaddingsShape.c_str());
|
||||
std::vector<Nd4jLong> expectedPaddingsShape = {rank, 2};
|
||||
std::vector<Nd4jLong> currentPaddingsShape = paddings->getShapeAsVector();
|
||||
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str());
|
||||
|
||||
NDArray padValue(input->dataType(), block.launchContext());
|
||||
|
||||
// in case of REFLECT and SYMMETRIC modes paddings must obey additional shape requirements
|
||||
if (INT_ARG(0) == 0) { // CONSTANT mode
|
||||
if(block.width() > 2) {
|
||||
REQUIRE_TRUE(input->dataType() == INPUT_VARIABLE(2)->dataType(), 0, "PAD op: data types of input and padValue arrays should be the same but got %i and %i correspondingly !", input->dataType(), INPUT_VARIABLE(2)->dataType());
|
||||
|
@ -68,10 +68,10 @@ CUSTOM_OP_IMPL(pad, 2, 1, false, 0, 1) {
|
|||
|
||||
// std::vector<int> dimensions(input->rankOf());
|
||||
// std::iota(dimensions.begin(), dimensions.end(), 0); // fill with 0, 1, ... rank-1
|
||||
|
||||
|
||||
// helpers::recursiveLoopForPad(INT_ARG(0), *input, *paddings, *output, dimensions, 0, 0, 0, padValue);
|
||||
helpers::pad(block.launchContext(), INT_ARG(0), *input, *paddings, *output, padValue);
|
||||
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -85,27 +85,26 @@ DECLARE_TYPES(pad) {
|
|||
|
||||
DECLARE_SHAPE_FN(pad) {
|
||||
|
||||
// check shape of paddings
|
||||
auto inputShapeInfo = inputShape->at(0);
|
||||
auto paddings = INPUT_VARIABLE(1);
|
||||
const int rank = inputShapeInfo[0];
|
||||
|
||||
// paddings validation
|
||||
std::string expectedPaddingsShape = ShapeUtils::shapeAsString({rank, 2});
|
||||
std::string currentPaddingsShape = ShapeUtils::shapeAsString(paddings);
|
||||
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", expectedPaddingsShape.c_str(), currentPaddingsShape.c_str());
|
||||
|
||||
const std::vector<Nd4jLong> expectedPaddingsShape = {rank, 2};
|
||||
const std::vector<Nd4jLong> currentPaddingsShape = paddings->getShapeAsVector();
|
||||
REQUIRE_TRUE(expectedPaddingsShape == currentPaddingsShape, 0, "PAD op: wrong shape of paddings array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedPaddingsShape).c_str(), ShapeUtils::shapeAsString(currentPaddingsShape).c_str());
|
||||
|
||||
Nd4jLong* outShapeInfo = nullptr;
|
||||
ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||
outShapeInfo[0] = rank;
|
||||
for(int i=1; i <= rank; ++i)
|
||||
outShapeInfo[i] = inputShapeInfo[i] + paddings->e<Nd4jLong>(i-1,0) + paddings->e<Nd4jLong>(i-1,1);
|
||||
|
||||
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, inputShapeInfo, shape::order(inputShapeInfo));
|
||||
ShapeDescriptor descriptor(outShapeInfo);
|
||||
RELEASE(outShapeInfo, block.getWorkspace());
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
|
||||
|
||||
}
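// Worked example of the shape arithmetic above (illustration only):
//
//     // input shape  : [2, 3]            (rank = 2)
//     // paddings     : [[1, 1], [2, 2]]  (shape [rank, 2])
//     // output shape : [2 + 1 + 1, 3 + 2 + 2] = [4, 7]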
|
||||
|
||||
|
||||
|
|
|
@ -27,15 +27,15 @@ namespace sd {
|
|||
namespace helpers {
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static void concat_(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
static void concat_(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
sd::SpecialMethods<T>::concatCpuGeneric(inArrs, output, axis);
|
||||
}
|
||||
|
||||
void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
BUILD_SINGLE_SELECTOR(output.dataType(), concat_,(inArrs, output, axis), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector<NDArray*>& inArrs, NDArray& output, const int axis), LIBND4J_TYPES);
|
||||
BUILD_SINGLE_TEMPLATE(template void concat_, (const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis), LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
}
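// A hypothetical call site illustrating the new const-qualified signature (array names and shapes
// are made up for the example):
//
//     // std::vector<const NDArray*> parts = {&a, &b, &c};            // e.g. [2,1,3], [2,5,3], [2,10,3]
//     // sd::ops::helpers::concat(block.launchContext(), parts, out /* [2,16,3] */, 1);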
|
|
@ -15,13 +15,14 @@
|
|||
******************************************************************************/
|
||||
|
||||
//
|
||||
// Created by Yurii Shyrma on 02.01.2018
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/stack.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <array/ResultSet.h>
|
||||
#include <execution/Threads.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
|
||||
|
||||
namespace sd {
|
||||
|
@ -31,37 +32,90 @@ namespace helpers {
|
|||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void stack_(const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||
static void stack_(const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
|
||||
|
||||
const int numOfSubArrs = inArrs.size();
|
||||
|
||||
if(inArrs[0]->rankOf() == 0) {
|
||||
int inSize = inArrs.size();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++)
|
||||
outArr->p<T>(i, inArrs[i]->t<T>(0));
|
||||
output.p<T>(i, inArrs[i]->t<T>(0));
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, inSize);
|
||||
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
|
||||
}
|
||||
else {
|
||||
|
||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
|
||||
auto list = outArr->allTensorsAlongDimension(dimsToExclude); // list.size() == block.width()
|
||||
int listSize = list.size();
|
||||
auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim}));
|
||||
Nd4jLong* zTadShapeInfo = zTadPack.primaryShapeInfo();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
||||
for (auto i = start; i < stop; i++) {
|
||||
|
||||
void* zBuff = output.bufferWithOffset(zTadPack.primaryOffsets()[i]);
|
||||
|
||||
NativeOpExecutioner::execTransformAny(inArrs[0]->getContext(), transform::Assign,
|
||||
inArrs[i]->getBuffer(), inArrs[i]->getShapeInfo(), nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
|
||||
zBuff, zTadShapeInfo, nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
|
||||
nullptr, nullptr, nullptr, false/*allowParallelism*/);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||
}
|
||||
|
||||
}
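// Illustration of the TAD-based branch above, shapes only:
//
//     // inArrs = { A[2,3], B[2,3], C[2,3] }, dim = 0   ->   output [3, 2, 3]
//     // zTadPack excludes dim 0, i.e. it describes 3 sub-arrays of shape [2, 3];
//     // iteration i assigns inArrs[i] into the output sub-array at primaryOffsets()[i].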
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (inArrs, output, dim), LIBND4J_TYPES);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim), LIBND4J_TYPES);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void unstack_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
|
||||
|
||||
const int numOfSubArrs = outArrs.size();
|
||||
|
||||
if(outArrs[0]->rankOf() == 0) {
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++)
|
||||
list.at(i)->assign(inArrs[i]);
|
||||
outArrs[i]->p<T>(0, input.t<T>(i));
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, listSize);
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, numOfSubArrs);
|
||||
}
|
||||
else {
|
||||
|
||||
auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim}));
|
||||
Nd4jLong* xTadShapeInfo = xTadPack.primaryShapeInfo();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
|
||||
void* xBuff = input.bufferWithOffset(xTadPack.primaryOffsets()[i]);
|
||||
|
||||
NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign,
|
||||
xBuff, xTadShapeInfo, nullptr/*input specialBuffer*/, nullptr/*input specialShapeInfo*/,
|
||||
outArrs[i]->getBuffer(), outArrs[i]->getShapeInfo(), nullptr/*output specialBuffer*/, nullptr/*output specialShapeInfo*/,
|
||||
nullptr, nullptr, nullptr, false/*allowParallelism*/);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
|
||||
}
|
||||
}
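// The symmetric picture for unstack, shapes only:
//
//     // input [3, 2, 4], dim = 0   ->   three outputs of shape [2, 4],
//     // each filled from the i-th TAD (sub-array) of the input.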
|
||||
|
||||
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (inArrs, outArr, dim), LIBND4J_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void stack_ , (const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (input, outArrs, dim), LIBND4J_TYPES);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void unstack_, (const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim), LIBND4J_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,14 +83,12 @@ __host__ static void concatCudaLauncher(const int blocksPerGrid, const int threa
|
|||
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
|
||||
const int numOfInArrs = inArrs.size();
|
||||
const auto sizeofT = output.sizeOfT();
|
||||
|
||||
for(int i = 0; i < numOfInArrs; ++i)
|
||||
inArrs[i]->syncToDevice();
|
||||
output.syncToDevice();
|
||||
NDArray::prepareSpecialUse({&output}, inArrs);
|
||||
|
||||
bool luckCase1 = ((axis == 0 && output.ordering() == 'c') || (axis == output.rankOf() - 1 && output.ordering() == 'f')) && output.ews() == 1;
|
||||
|
||||
|
@ -122,43 +120,48 @@ void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, ND
|
|||
return;
|
||||
}
|
||||
|
||||
const bool isZcontin = output.strideAt(axis) == 1;
|
||||
bool areInputsContin = true;
|
||||
bool allSameOrder = true;
|
||||
// const bool isZcontin = output.strideAt(axis) == 1;
|
||||
// bool areInputsContin = true;
|
||||
// bool allSameOrder = true;
|
||||
// std::vector<Nd4jLong> strideOfContigStride(numOfInArrs);
|
||||
|
||||
if(isZcontin) {
|
||||
for (uint i = 0; i < inArrs.size(); ++i) {
|
||||
areInputsContin &= inArrs[i]->strideAt(axis) == 1;
|
||||
allSameOrder &= output.ordering() == inArrs[i]->ordering();
|
||||
if(!areInputsContin || !allSameOrder)
|
||||
break;
|
||||
}
|
||||
}
|
||||
// if(isZcontin) {
|
||||
|
||||
const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
|
||||
// for (uint i = 0; i < inArrs.size(); ++i) {
|
||||
|
||||
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and the output array
|
||||
// areInputsContin &= inArrs[i]->strideAt(axis) == 1;
|
||||
// allSameOrder &= output.ordering() == inArrs[i]->ordering();
|
||||
// if(!areInputsContin || !allSameOrder)
|
||||
// break;
|
||||
|
||||
const uint zDim = output.sizeAt(axis);
|
||||
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
|
||||
// }
|
||||
// }
|
||||
|
||||
for (uint i = 0; i < output.lengthOf() / zDim; ++i) {
|
||||
// const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
|
||||
|
||||
const auto iShift = i * sizeofT;
|
||||
void* z = static_cast<int8_t*>(output.getSpecialBuffer()) + zDim * iShift;
|
||||
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and the output array
|
||||
|
||||
for (uint j = 0; j < numOfInArrs; ++j) {
|
||||
const auto xDim = inArrs[j]->sizeAt(axis);
|
||||
void* x = static_cast<int8_t*>(inArrs[j]->getSpecialBuffer()) + xDim * iShift;
|
||||
const auto memSizeToCopy = xDim * sizeofT;
|
||||
cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
|
||||
z = static_cast<int8_t*>(z) + memSizeToCopy;
|
||||
}
|
||||
}
|
||||
// const auto zStep = shape::strideOverContigAxis(axis, output.getShapeInfo());
|
||||
|
||||
if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
|
||||
throw std::runtime_error("concat cuda: luckCase2 failed!");
|
||||
}
|
||||
else { // general (slower) case
|
||||
// for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
|
||||
|
||||
// const auto iShift = i * sizeofT;
|
||||
// void* z = static_cast<int8_t*>(output.getSpecialBuffer()) + zStep * iShift;
|
||||
|
||||
// for (uint j = 0; j < numOfInArrs; ++j) {
|
||||
// const auto xDim = inArrs[j]->sizeAt(axis);
|
||||
// void* x = static_cast<int8_t*>(inArrs[j]->getSpecialBuffer()) + strideOfContigStride[j] * iShift;
|
||||
// const auto memSizeToCopy = xDim * sizeofT;
|
||||
// cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
|
||||
// z = static_cast<int8_t*>(z) + memSizeToCopy;
|
||||
// }
|
||||
// }
|
||||
|
||||
// if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
|
||||
// throw std::runtime_error("concat cuda: luckCase2 failed!");
|
||||
// }
|
||||
// else { // general (slower) case
|
||||
|
||||
const int threadsPerBlock = 256;
|
||||
const int blocksPerGrid = 512;
|
||||
|
@ -181,11 +184,9 @@ void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, ND
|
|||
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES);
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
// }
|
||||
|
||||
for(int i = 0; i < numOfInArrs; ++i)
|
||||
inArrs[i]->tickReadDevice();
|
||||
output.tickWriteDevice();
|
||||
NDArray::registerSpecialUse({&output}, inArrs);
|
||||
}
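// Illustration of the contiguous case targeted above (luckCase1: axis 0 with 'c' ordering, or the
// last axis with 'f' ordering, and ews == 1), shapes only:
//
//     // inputs [1,3], [5,3], [10,3]  ->  output [16,3]
//     // each input then occupies one contiguous block of lengthOf() * sizeofT bytes,
//     // so a single device-to-device copy per input at a running byte offset suffices.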
|
||||
|
||||
}
|
||||
|
|
|
@ -48,11 +48,11 @@ __global__ static void splitCuda(const void* vx, const Nd4jLong* xShapeInfo, voi
|
|||
xLen = shape::length(xShapeInfo);
|
||||
xRank = shape::rank(xShapeInfo);
|
||||
zDim = shape::shapeOf(zTadShapeInfo)[axis]; // same for all input arrays
|
||||
totalThreads = gridDim.z * blockDim.z;
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.z * blockDim.z + threadIdx.z;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
Nd4jLong coords[MAX_RANK];
|
||||
|
||||
|
@ -121,43 +121,48 @@ void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray
|
|||
return;
|
||||
}
|
||||
|
||||
const bool isXcontin = input.strideAt(axis) == 1;
|
||||
bool areOutputsContin = true;
|
||||
bool allSameOrder = true;
|
||||
// const bool isXcontin = input.strideAt(axis) == 1;
|
||||
// bool areOutputsContin = true;
|
||||
// bool allSameOrder = true;
|
||||
// std::vector<Nd4jLong> strideOfContigStride(outArrs.size());
|
||||
|
||||
if(isXcontin) {
|
||||
for (uint i = 0; i < outArrs.size(); ++i) {
|
||||
areOutputsContin &= outArrs[i]->strideAt(axis) == 1;
|
||||
allSameOrder &= input.ordering() == outArrs[i]->ordering();
|
||||
if(!areOutputsContin || !allSameOrder)
|
||||
break;
|
||||
}
|
||||
}
|
||||
// if(isXcontin) {
|
||||
|
||||
const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder;
|
||||
// for (uint i = 0; i < outArrs.size(); ++i) {
|
||||
|
||||
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for the input array and all output arrays
|
||||
// areOutputsContin &= outArrs[i]->strideAt(axis) == 1;
|
||||
// allSameOrder &= input.ordering() == outArrs[i]->ordering();
|
||||
// if(!areOutputsContin || !allSameOrder)
|
||||
// break;
|
||||
|
||||
const auto xDim = input.sizeAt(axis);
|
||||
const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs
|
||||
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->getShapeInfo());
|
||||
// }
|
||||
// }
|
||||
|
||||
for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
|
||||
// const bool luckCase2 = isXcontin && areOutputsContin && allSameOrder;
|
||||
|
||||
const auto iShift = i * sizeofT;
|
||||
void* x = static_cast<int8_t*>(input.getSpecialBuffer()) + xDim * iShift;
|
||||
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for the input array and all output arrays
|
||||
|
||||
for (uint j = 0; j < numOfSubArrs; ++j) {
|
||||
void* z = static_cast<int8_t*>(outArrs[j]->getSpecialBuffer()) + zDim * iShift;
|
||||
const auto memSizeToCopy = zDim * sizeofT;
|
||||
cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
|
||||
x = static_cast<int8_t*>(x) + memSizeToCopy;
|
||||
}
|
||||
}
|
||||
// const auto xStep = shape::strideOverContigAxis(axis, input.getShapeInfo());
|
||||
// const auto zDim = outArrs[0]->sizeAt(axis); // same for all outArrs
|
||||
|
||||
if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
|
||||
throw std::runtime_error("split cuda: luckCase2 failed!");
|
||||
}
|
||||
else { // general (slower) case
|
||||
// for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) {
|
||||
|
||||
// const auto iShift = i * sizeofT;
|
||||
// void* x = static_cast<int8_t*>(input.getSpecialBuffer()) + xStep * iShift;
|
||||
|
||||
// for (uint j = 0; j < numOfSubArrs; ++j) {
|
||||
// void* z = static_cast<int8_t*>(outArrs[j]->getSpecialBuffer()) + strideOfContigStride[j] * iShift;
|
||||
// const auto memSizeToCopy = zDim * sizeofT;
|
||||
// cudaMemcpyAsync(z, x, memSizeToCopy, cudaMemcpyDeviceToDevice, *context->getCudaStream());
|
||||
// x = static_cast<int8_t*>(x) + memSizeToCopy;
|
||||
// }
|
||||
// }
|
||||
|
||||
// if(cudaStreamSynchronize(*context->getCudaStream()) != 0)
|
||||
// throw std::runtime_error("split cuda: luckCase2 failed!");
|
||||
// }
|
||||
// else { // general (slower) case
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
@ -175,7 +180,7 @@ void split(sd::LaunchContext* context, const NDArray& input, std::vector<NDArray
|
|||
BUILD_SINGLE_SELECTOR(input.dataType(), splitCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->specialShapeInfo(), axis), LIBND4J_TYPES);
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
// }
|
||||
|
||||
for(int i = 0; i < numOfSubArrs; ++i)
|
||||
outArrs[i]->tickWriteDevice();
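For reference, the device-to-device copies in the luckCase2 branch above follow the same access pattern as this host-side sketch: when the split axis has stride 1 in the input and in every output and the orders match, the input decomposes into lengthOf/xDim contiguous chunks, and each chunk is the concatenation of one zDim-sized piece per output. Plain C++ with illustrative names, assuming equally sized outputs as the split op produces:

#include <cstring>
#include <vector>

void splitContiguousSketch(const float* input, size_t numOuter, size_t xDim,
                           std::vector<std::vector<float>>& outs, size_t zDim) {
    // assumes xDim == outs.size() * zDim and every outs[j] holds numOuter * zDim elements
    for (size_t i = 0; i < numOuter; ++i) {
        const float* x = input + i * xDim;                             // i-th contiguous input chunk
        for (size_t j = 0; j < outs.size(); ++j) {
            std::memcpy(outs[j].data() + i * zDim, x, zDim * sizeof(float));
            x += zDim;                                                 // advance within the chunk
        }
    }
}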
@ -31,79 +31,333 @@ namespace ops {
|
|||
namespace helpers {
|
||||
|
||||
|
||||
template <typename T>
|
||||
static __global__ void stackKernel(void** inputList, void** inputShapeList, int inputListLength, Nd4jLong arrLen, void* vz, const Nd4jLong* zShapeInfo, Nd4jLong* tadShape, Nd4jLong *tadOffsets) {
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static __global__ void stackScalarsCuda(void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
if(tadShape == nullptr) { // scalar case
|
||||
__shared__ Nd4jLong zLen, totalThreads;
|
||||
|
||||
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < inputListLength; i += gridDim.x * blockDim.x)
|
||||
z[shape::getIndexOffset(i, zShapeInfo)] = reinterpret_cast<T*>(inputList[i])[0];
|
||||
}
|
||||
else {
|
||||
if (threadIdx.x == 0) {
|
||||
zLen = shape::length(zShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int t = blockIdx.x; t < inputListLength; t += gridDim.x) {
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
auto tZ = z + tadOffsets[t];
|
||||
auto tX = reinterpret_cast<T*>(inputList[t]);
|
||||
auto xShapeInfo = reinterpret_cast<Nd4jLong*>(inputShapeList[t]);
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
for (int e = threadIdx.x; e < arrLen; e += blockDim.x)
|
||||
tZ[shape::getIndexOffset(e, tadShape)] = tX[shape::getIndexOffset(e, xShapeInfo)];
|
||||
}
|
||||
}
|
||||
}
|
||||
const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[i]);
|
||||
z[shape::getIndexOffset(i, zShapeInfo)] = *x;
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void stack_(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||
|
||||
const bool scalarCase = inArrs[0]->isScalar();
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__host__ static void stackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
void* pVx, void* vz, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size();
|
||||
stackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, vz, zShapeInfo);
|
||||
}
|
||||
|
||||
NDArray::prepareSpecialUse({outArr}, {});
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void stack_(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
|
||||
|
||||
// FIXME: !!!
|
||||
for (auto v:inArrs)
|
||||
NDArray::prepareSpecialUse({}, {v});
|
||||
const int numOfSubArrs = inArrs.size();
|
||||
|
||||
std::vector<void const*> inputList(inArrs.size());
|
||||
std::vector<Nd4jLong const*> inputShapeList(inArrs.size());
|
||||
NDArray::prepareSpecialUse({&output}, inArrs);
|
||||
|
||||
for (size_t i = 0; i < inputList.size(); ++i) {
|
||||
inputList[i] = inArrs[i]->getSpecialBuffer();
|
||||
inputShapeList[i] = inArrs[i]->getSpecialShapeInfo();
|
||||
}
|
||||
if(inArrs[0]->rankOf() == 0) {
|
||||
|
||||
PointersManager manager(context, "helpers::stack");
|
||||
auto dInBuffers = (void **) manager.replicatePointer(inputList.data(), inputList.size() * sizeof(Nd4jLong*));
|
||||
auto dInShapeInfo = (void **) manager.replicatePointer(inputShapeList.data(), inputShapeList.size() * sizeof(Nd4jLong*));
|
||||
std::vector<void*> hInBuffers(numOfSubArrs);
|
||||
|
||||
for(int i = 0; i < numOfSubArrs; ++i)
|
||||
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
|
||||
|
||||
PointersManager manager(context, "helpers::stack cuda");
|
||||
|
||||
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
stackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, output.specialBuffer(), output.specialShapeInfo());
|
||||
|
||||
if(scalarCase) {
|
||||
stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), outArr->getSpecialShapeInfo(), nullptr, nullptr);
|
||||
}
|
||||
else {
|
||||
std::vector<int> axis = ShapeUtils::evalDimsToExclude(outArr->rankOf(), {dim});
|
||||
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(outArr->getShapeInfo(), axis);
|
||||
stackKernel<T><<<blocksPerGrid, threadsPerBlock, 1024, *context->getCudaStream()>>>((void**)dInBuffers, (void**)dInShapeInfo, inputList.size(), inArrs[0]->lengthOf(), outArr->specialBuffer(), nullptr, packZ.specialShapeInfo(), packZ.specialOffsets());
|
||||
}
|
||||
manager.synchronize();
|
||||
}
|
||||
else {
|
||||
|
||||
NDArray::registerSpecialUse({outArr}, {});
|
||||
auto zTadPack = ConstantTadHelper::getInstance()->tadForDimensions(output.getShapeInfo(), ShapeUtils::evalDimsToExclude(output.rankOf(), {dim}));
|
||||
Nd4jLong* zTadShapeInfo = zTadPack.primaryShapeInfo();
|
||||
|
||||
// FIXME: !!!
|
||||
for (auto v:inArrs)
|
||||
NDArray::registerSpecialUse({}, {v});
|
||||
}
|
||||
for (uint i = 0; i < numOfSubArrs; ++i) {
|
||||
|
||||
void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(outArr->dataType(), stack_, (context, inArrs, outArr, dim), LIBND4J_TYPES);
|
||||
}
|
||||
void* zBuff = output.specialBufferWithOffset(zTadPack.primaryOffsets()[i]);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim), LIBND4J_TYPES);
|
||||
NativeOpExecutioner::execTransformAny(context, transform::Assign,
|
||||
nullptr, inArrs[i]->getShapeInfo(), inArrs[i]->getSpecialBuffer(), inArrs[i]->getSpecialShapeInfo(),
|
||||
nullptr, zTadShapeInfo, zBuff, zTadPack.specialShapeInfo(),
|
||||
nullptr, nullptr, nullptr, false/*allowParallelism*/);
|
||||
}
|
||||
}
|
||||
|
||||
NDArray::registerSpecialUse({&output}, inArrs);
|
||||
}
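At the op level, the non-scalar path above amounts to copying each input into the i-th sub-array (TAD excluding the stacking dimension) of the output. A short usage sketch that mirrors the Stack_* tests further down in this change:

// Stacking two 3x4 arrays along dimension 0 produces a 2x3x4 array whose i-th
// sub-array is a copy of the i-th input.
auto a = NDArrayFactory::create<float>('c', {3, 4});
auto b = NDArrayFactory::create<float>('c', {3, 4});
a.linspace(1);     // 1 .. 12
b.linspace(13);    // 13 .. 24

sd::ops::stack op;
auto results = op.evaluate({&a, &b}, {}, {0});   // integer argument = stacking dimension
auto output = results->at(0);                    // shape {2, 3, 4}

delete results;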
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(output.dataType(), stack_, (context, inArrs, output, dim), LIBND4J_TYPES);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void stack_ , (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int dim), LIBND4J_TYPES);
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static __global__ void unstackScalarsCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
|
||||
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
|
||||
__shared__ Nd4jLong xLen, totalThreads;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
xLen = shape::length(xShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < xLen; i += totalThreads) {
|
||||
|
||||
T* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[i]);
|
||||
*z = x[shape::getIndexOffset(i, xShapeInfo)];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__host__ static void unstackScalarsCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
const void* vx, const Nd4jLong* xShapeInfo, void* pVz) {
|
||||
|
||||
unstackScalarsCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
static void unstack_(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
|
||||
|
||||
const int numOfSubArrs = outArrs.size();
|
||||
|
||||
// NDArray::prepareSpecialUse(outArrs, {&input});
|
||||
input.syncToDevice();
|
||||
for (const auto a : outArrs)
|
||||
a->getDataBuffer()->allocateSpecial();
|
||||
|
||||
|
||||
if(outArrs[0]->rankOf() == 0) {
|
||||
|
||||
std::vector<void*> hOutBuffers(numOfSubArrs);
|
||||
|
||||
for(int i = 0; i < numOfSubArrs; ++i)
|
||||
hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
|
||||
|
||||
PointersManager manager(context, "helpers::unstack cuda");
|
||||
|
||||
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
unstackScalarsCudaLauncher<T>(blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers);
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
else {
|
||||
|
||||
auto xTadPack = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), ShapeUtils::evalDimsToExclude(input.rankOf(), {dim}));
|
||||
Nd4jLong* xTadShapeInfo = xTadPack.primaryShapeInfo();
|
||||
|
||||
for (uint i = 0; i < numOfSubArrs; ++i) {
|
||||
|
||||
void* xBuff = input.specialBufferWithOffset(xTadPack.primaryOffsets()[i]);
|
||||
|
||||
NativeOpExecutioner::execTransformAny(input.getContext(), transform::Assign,
|
||||
nullptr, xTadShapeInfo, xBuff, xTadPack.specialShapeInfo(),
|
||||
nullptr, outArrs[i]->getShapeInfo(), outArrs[i]->specialBuffer(), outArrs[i]->specialShapeInfo(),
|
||||
nullptr, nullptr, nullptr, false/*allowParallelism*/);
|
||||
}
|
||||
}
|
||||
|
||||
// NDArray::registerSpecialUse(outArrs, {&input});
|
||||
input.tickReadDevice();
|
||||
for (const auto p : outArrs)
|
||||
p->tickWriteDevice();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), unstack_, (context, input, outArrs, dim), LIBND4J_TYPES);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void unstack_, (sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim), LIBND4J_TYPES);
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
// template <typename T>
|
||||
// static __global__ void unstackCuda(const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
|
||||
|
||||
// const T* x = reinterpret_cast<const T*>(vx);
|
||||
// __shared__ Nd4jLong xLen, totalThreads;
|
||||
// __shared__ int xRank;
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
// xLen = shape::length(xShapeInfo);
|
||||
// xRank = shape::rank(xShapeInfo);
|
||||
// totalThreads = gridDim.x * blockDim.x;
|
||||
// }
|
||||
// __syncthreads();
|
||||
|
||||
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Nd4jLong coords[MAX_RANK];
|
||||
|
||||
// for (uint64_t i = tid; i < xLen; i += totalThreads) {
|
||||
|
||||
// shape::index2coords(i, xShapeInfo, coords);
|
||||
|
||||
// const auto xOffset = shape::getOffset(xShapeInfo, coords);
|
||||
|
||||
// T *z = reinterpret_cast<T*>(reinterpret_cast<void **>(pVz)[coords[axis]]);
|
||||
|
||||
// for (uint j = axis; j < xRank - 1; ++j) // shift coords starting from axis position
|
||||
// coords[j] = coords[j + 1];
|
||||
|
||||
// const auto zOffset = shape::getOffset(zTadShapeInfo, coords);
|
||||
|
||||
// z[zOffset] = x[xOffset];
|
||||
// }
|
||||
// }
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////
|
||||
// template<typename T>
|
||||
// __host__ static void unstackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
// const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis) {
|
||||
|
||||
// unstackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, pVz, zTadShapeInfo, axis);
|
||||
// }
|
||||
// BUILD_SINGLE_TEMPLATE(template void unstackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void* vx, const Nd4jLong* xShapeInfo, void* pVz, const Nd4jLong* zTadShapeInfo, const int axis), LIBND4J_TYPES);
|
||||
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////
|
||||
// void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<const NDArray*>& outArrs, const int axis) {
|
||||
|
||||
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
// const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
// const int numOfSubArrs = outArrs.size();
|
||||
|
||||
// std::vector<void*> hOutBuffers(numOfSubArrs);
|
||||
|
||||
// for(int i = 0; i < numOfSubArrs; ++i)
|
||||
// hOutBuffers[i] = outArrs[i]->getSpecialBuffer();
|
||||
|
||||
// PointersManager manager(context, "helpers::unstack");
|
||||
|
||||
// void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
|
||||
|
||||
// for(uint i = 0; i < numOfSubArrs; ++i)
|
||||
// outArrs[i]->syncToDevice();
|
||||
// input.syncToDevice();
|
||||
|
||||
// BUILD_SINGLE_SELECTOR(input.dataType(), unstackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), dOutBuffers, outArrs[0]->getSpecialShapeInfo(), axis), LIBND4J_TYPES);
|
||||
|
||||
// manager.synchronize();
|
||||
|
||||
// for(uint i = 0; i < numOfSubArrs; ++i)
|
||||
// outArrs[i]->tickReadDevice();
|
||||
// input.tickWriteDevice();
|
||||
// }
|
||||
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////
|
||||
// template <typename T>
|
||||
// static __global__ void stackCuda(void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
|
||||
|
||||
// T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
// __shared__ Nd4jLong zLen, totalThreads;
|
||||
// __shared__ int zRank;
|
||||
|
||||
// if (threadIdx.x == 0) {
|
||||
// zLen = shape::length(zShapeInfo);
|
||||
// zRank = shape::rank(zShapeInfo);
|
||||
// totalThreads = gridDim.x * blockDim.x;
|
||||
// }
|
||||
// __syncthreads();
|
||||
|
||||
// const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// Nd4jLong coords[MAX_RANK];
|
||||
|
||||
// for (uint64_t i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
// shape::index2coords(i, zShapeInfo, coords);
|
||||
|
||||
// const auto zOffset = shape::getOffset(zShapeInfo, coords);
|
||||
|
||||
// const T *x = reinterpret_cast<const T*>(reinterpret_cast<void**>(pVx)[coords[axis]]);
|
||||
|
||||
// for (uint j = axis; j < zRank - 1; ++j) // shift coords starting from axis position
|
||||
// coords[j] = coords[j + 1];
|
||||
|
||||
// const auto xOffset = shape::getOffset(xTadShapeInfo, coords);
|
||||
|
||||
// z[zOffset] = x[xOffset];
|
||||
// }
|
||||
// }
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////
|
||||
// template<typename T>
|
||||
// __host__ static void stackCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
// void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis) {
|
||||
|
||||
// stackCuda<T><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(pVx, xTadShapeInfo, vz, zShapeInfo, axis);
|
||||
// }
|
||||
// BUILD_SINGLE_TEMPLATE(template void stackCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, void* pVx, const Nd4jLong* xTadShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
|
||||
|
||||
|
||||
// ///////////////////////////////////////////////////////////////////
|
||||
// void stack(sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
|
||||
// const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
// const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
// const int numOfSubArrs = inArrs.size();
|
||||
|
||||
// std::vector<void*> hInBuffers(numOfSubArrs);
|
||||
|
||||
// for(int i = 0; i < numOfSubArrs; ++i)
|
||||
// hInBuffers[i] = inArrs[i]->getSpecialBuffer();
|
||||
|
||||
// PointersManager manager(context, "helpers::stack");
|
||||
|
||||
// void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
|
||||
|
||||
// for(uint i = 0; i < numOfSubArrs; ++i)
|
||||
// inArrs[i]->syncToDevice();
|
||||
// output.syncToDevice();
|
||||
|
||||
// BUILD_SINGLE_SELECTOR(output.dataType(), stackCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), dInBuffers, inArrs[0]->getSpecialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), axis), LIBND4J_TYPES);
|
||||
|
||||
// manager.synchronize();
|
||||
|
||||
// for(uint i = 0; i < numOfSubArrs; ++i)
|
||||
// inArrs[i]->tickReadDevice();
|
||||
// output.tickWriteDevice();
|
||||
// }
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -124,19 +124,19 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr
|
|||
    if(m < n)
        throw std::runtime_error("svdQR: due to cuda api constraints the given shape of A array is not valid !");

    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
    if(std::vector<Nd4jLong>({minDim}) != S->getShapeAsVector())
        throw std::runtime_error("svdQR: wrong shape of S array !");

    if(calcUV) {

        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
        if(fullUV && std::vector<Nd4jLong>({m,m}) != U->getShapeAsVector())
            throw std::runtime_error("svdQR: wrong shape of U array !");
        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
        else if(!fullUV && std::vector<Nd4jLong>({m,minDim}) != U->getShapeAsVector())
            throw std::runtime_error("svdQR: wrong shape of U array !");

        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(VT))
        if(fullUV && std::vector<Nd4jLong>({n,n}) != VT->getShapeAsVector())
            throw std::runtime_error("svdQR: wrong shape of VT array !");
        else if(!fullUV && ShapeUtils::shapeAsString({minDim,n}) != ShapeUtils::shapeAsString(VT))
        else if(!fullUV && std::vector<Nd4jLong>({minDim,n}) != VT->getShapeAsVector())
            throw std::runtime_error("svdQR: wrong shape of VT array !");
    }
@ -280,19 +280,19 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA
|
|||
    int n = A->sizeAt(1);
    const int minDim = m < n ? m : n;

    if(ShapeUtils::shapeAsString({minDim}) != ShapeUtils::shapeAsString(S))
    if(std::vector<Nd4jLong>({minDim}) != S->getShapeAsVector())
        throw std::runtime_error("svdJcb: wrong shape of S array !");

    if(calcUV) {

        if(fullUV && ShapeUtils::shapeAsString({m,m}) != ShapeUtils::shapeAsString(U))
        if(fullUV && std::vector<Nd4jLong>({m,m}) != U->getShapeAsVector())
            throw std::runtime_error("svdJcb: wrong shape of U array !");
        else if(!fullUV && ShapeUtils::shapeAsString({m,minDim}) != ShapeUtils::shapeAsString(U))
        else if(!fullUV && std::vector<Nd4jLong>({m,minDim}) != U->getShapeAsVector())
            throw std::runtime_error("svdJcb: wrong shape of U array !");

        if(fullUV && ShapeUtils::shapeAsString({n,n}) != ShapeUtils::shapeAsString(V))
        if(fullUV && std::vector<Nd4jLong>({n,n}) != V->getShapeAsVector())
            throw std::runtime_error("svdJcb: wrong shape of V array !");
        else if(!fullUV && ShapeUtils::shapeAsString({n,minDim}) != ShapeUtils::shapeAsString(V))
        else if(!fullUV && std::vector<Nd4jLong>({n,minDim}) != V->getShapeAsVector())
            throw std::runtime_error("svdJcb: wrong shape of V array !");
    }
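Replacing the string-based shape comparison with a direct vector comparison against getShapeAsVector() avoids building two temporary strings for every check. A minimal sketch of the idiom; the checkShape helper is hypothetical and not part of the library:

#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: compare expected dimensions with NDArray::getShapeAsVector() directly
// and only produce a human-readable message when the check actually fails.
static void checkShape(const NDArray* arr, const std::vector<Nd4jLong>& expected, const std::string& name) {
    if (expected != arr->getShapeAsVector())
        throw std::runtime_error("wrong shape of " + name + " array !");
}

// usage sketch: checkShape(S, {minDim}, "S");
//               checkShape(U, fullUV ? std::vector<Nd4jLong>{m, m} : std::vector<Nd4jLong>{m, minDim}, "U");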
@ -28,7 +28,8 @@ namespace sd {
|
|||
namespace ops {
namespace helpers {

    void stack(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray* outArr, const int dim);
    void stack (sd::LaunchContext* context, const std::vector<const NDArray*>& inArrs, NDArray& outArr, const int dim);
    void unstack(sd::LaunchContext* context, const NDArray& input, const std::vector<NDArray*>& outArrs, const int dim);


}
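With the output now passed by reference rather than by pointer, call sites only change at the output argument. A minimal usage sketch under the new signatures; the default LaunchContext is used here for brevity (assumed available via sd::LaunchContext::defaultContext()), whereas the ops obtain it from the block:

auto context = sd::LaunchContext::defaultContext();

auto a = NDArrayFactory::create<float>('c', {3, 4});
auto b = NDArrayFactory::create<float>('c', {3, 4});
auto stacked = NDArrayFactory::create<float>('c', {2, 3, 4});

std::vector<const NDArray*> ins{&a, &b};
sd::ops::helpers::stack(context, ins, stacked, 0);     // writes the stacked result into 'stacked'

std::vector<NDArray*> outs{&a, &b};
sd::ops::helpers::unstack(context, stacked, outs, 0);  // writes each sub-array back into a and b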
@ -70,7 +70,11 @@ namespace helpers {
|
|||
|
||||
    void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);

    void concat(sd::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
    void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);

    void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);

    void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis);

    void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
@ -43,11 +43,8 @@ namespace sd {
|
|||
|
||||
        NDArray::prepareSpecialUse({z}, {x, y});

        if (!x->isSameShape(y)) {
            std::string sx = ShapeUtils::shapeAsString(x);
            std::string sy = ShapeUtils::shapeAsString(y);
            REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), sx.c_str(), sy.c_str());
        }
        if (!x->isSameShape(y))
            REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());

        int opNum = block.opNum() < 0 ? this->_opNum : block.opNum();
@ -43,11 +43,8 @@ namespace sd {
|
|||
|
||||
        NDArray::prepareSpecialUse({z}, {x, y});

        if (!x->isSameShape(y)) {
            std::string sx = ShapeUtils::shapeAsString(x);
            std::string sy = ShapeUtils::shapeAsString(y);
            REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), sx.c_str(), sy.c_str());
        }
        if (!x->isSameShape(y))
            REQUIRE_TRUE(x->isSameShape(y) || y->isScalar(), 0, "Node_%i: For Pairwise transforms shapes of both operands should be equal but got %s vs %s", block.getNodeId(), ShapeUtils::shapeAsString(x).c_str(), ShapeUtils::shapeAsString(y).c_str());

        int opNum = block.opNum() < 0 ? this->_opNum : block.opNum();
@ -108,7 +108,7 @@ namespace sd {
|
|||
// }

template <typename T>
void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {

    const int numOfInArrs = inArrs.size();
    const auto sizeofT = output.sizeOfT();
@ -136,38 +136,44 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, ND
|
|||
return;
|
||||
}
|
||||
|
||||
const bool isZcontin = output.strideAt(axis) == 1 && output.ordering() == 'c';
|
||||
bool areInputsContin = true;
|
||||
bool allSameOrder = true;
|
||||
// const bool isZcontin = output.strideAt(axis) == 1;
|
||||
// bool areInputsContin = true;
|
||||
// bool allSameOrder = true;
|
||||
// std::vector<Nd4jLong> strideOfContigStride(numOfInArrs);
|
||||
|
||||
if(isZcontin) {
|
||||
for (uint i = 0; i < numOfInArrs; ++i) {
|
||||
areInputsContin &= inArrs[i]->strideAt(axis) == 1;
|
||||
allSameOrder &= inArrs[i]->ordering() == output.ordering();
|
||||
if(!areInputsContin || !allSameOrder)
|
||||
break;
|
||||
}
|
||||
}
|
||||
// if(isZcontin) {
|
||||
|
||||
const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
|
||||
// for (uint i = 0; i < numOfInArrs; ++i) {
|
||||
|
||||
if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and the output array
|
||||
// areInputsContin &= inArrs[i]->strideAt(axis) == 1;
|
||||
// allSameOrder &= inArrs[i]->ordering() == output.ordering();
|
||||
// if(!areInputsContin || !allSameOrder)
|
||||
// break;
|
||||
|
||||
const uint zDim = output.sizeAt(axis);
|
||||
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
|
||||
// }
|
||||
// }
|
||||
|
||||
for (uint i = 0; i < output.lengthOf() / zDim; ++i) {
|
||||
T* z = zBuff + zDim * i;
|
||||
// const bool luckCase2 = isZcontin && areInputsContin && allSameOrder;
|
||||
|
||||
for (uint j = 0; j < inArrs.size(); ++j) {
|
||||
const auto xDim = inArrs[j]->sizeAt(axis);
|
||||
const T* x = inArrs[j]->bufferAsT<T>() + xDim * i;
|
||||
memcpy(z, x, xDim * sizeofT);
|
||||
z += xDim;
|
||||
}
|
||||
}
|
||||
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and the output array
|
||||
|
||||
return;
|
||||
}
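The general per-element path (truncated in this hunk) needs to map an output coordinate along the concatenation axis back to the input it came from. Because the inputs may have different extents along that axis, the mapping subtracts the extents one after another; a standalone illustration using the {1, 5, 10} sizes from the comment above:

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> axisSizes{1, 5, 10};       // extents of the inputs along the concat axis
    for (int c = 0; c < 16; ++c) {              // 16 = 1 + 5 + 10, extent of the output axis
        int local = c, inputIdx = 0;
        while (local >= axisSizes[inputIdx]) {  // walk through the inputs until the coordinate fits
            local -= axisSizes[inputIdx];
            ++inputIdx;
        }
        std::printf("output axis coord %2d -> input %d, local coord %d\n", c, inputIdx, local);
    }
    return 0;
}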
|
||||
// const auto zStep = shape::strideOverContigAxis(axis, output.getShapeInfo());
|
||||
|
||||
// for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
|
||||
|
||||
// T* z = zBuff + zStep * i;
|
||||
|
||||
// for (uint j = 0; j < inArrs.size(); ++j) {
|
||||
// const auto xDim = inArrs[j]->sizeAt(axis);
|
||||
// const T* x = inArrs[j]->bufferAsT<T>() + strideOfContigStride[j] * i;
|
||||
// memcpy(z, x, xDim * sizeofT);
|
||||
// z += xDim;
|
||||
// }
|
||||
// }
|
||||
|
||||
// return;
|
||||
// }
|
||||
|
||||
// general case
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
@ -204,7 +210,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<NDArray*>& inArrs, ND
|
|||
template <typename T>
|
||||
void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *vresult, Nd4jLong *resultShapeInfo) {
|
||||
auto result = reinterpret_cast<T *>(vresult);
|
||||
std::vector<NDArray*> inputs(numArrays);
|
||||
std::vector<const NDArray*> inputs(numArrays);
|
||||
|
||||
NDArray output(static_cast<void*>(result), static_cast<Nd4jLong*>(resultShapeInfo));
|
||||
|
||||
|
@ -217,6 +223,104 @@ void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint
|
|||
delete inputs[i];
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void SpecialMethods<T>::splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
|
||||
|
||||
int numSplits = outArrs.size();
|
||||
|
||||
const auto sizeofT = input.sizeOfT();
|
||||
|
||||
T* xBuff = input.bufferAsT<T>();
|
||||
|
||||
bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
|
||||
|
||||
if (luckCase1) {
|
||||
for (uint i = 0; i < numSplits; ++i) {
|
||||
luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
|
||||
if (!luckCase1)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (luckCase1) {
|
||||
|
||||
T* x = const_cast<T*>(xBuff);
|
||||
for (uint i = 0; i < numSplits; ++i) {
|
||||
const auto memAmountToCopy = outArrs[i]->lengthOf();
|
||||
memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
|
||||
x += memAmountToCopy;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// const bool isXcontin = input.strideAt(axis) == 1;
|
||||
// bool areOutsContin = true;
|
||||
// bool allSameOrder = true;
|
||||
// std::vector<Nd4jLong> strideOfContigStride(numSplits);
|
||||
|
||||
// if (isXcontin) {
|
||||
|
||||
// for (uint i = 0; i < numSplits; ++i) {
|
||||
|
||||
// areOutsContin &= outArrs[i]->strideAt(axis) == 1;
|
||||
// allSameOrder &= outArrs[i]->ordering() == input.ordering();
|
||||
// if (!areOutsContin || !allSameOrder)
|
||||
// break;
|
||||
|
||||
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, outArrs[i]->getShapeInfo());
|
||||
// }
|
||||
// }
|
||||
|
||||
// const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
|
||||
|
||||
// if (luckCase2) {
|
||||
|
||||
// const auto xStep = shape::strideOverContigAxis(axis, input.getShapeInfo());
|
||||
|
||||
// for (uint i = 0; i < input.lengthOf() / input.sizeAt(axis); ++i) {
|
||||
|
||||
// T* x = xBuff + xStep * i;
|
||||
|
||||
// for (uint j = 0; j < numSplits; ++j) {
|
||||
// const auto zDim = outArrs[j]->sizeAt(axis);
|
||||
// T* z = outArrs[j]->bufferAsT<T>() + strideOfContigStride[j] * i;
|
||||
// memcpy(z, x, zDim * sizeofT);
|
||||
// x += zDim;
|
||||
// }
|
||||
// }
|
||||
|
||||
// return;
|
||||
// }
|
||||
|
||||
uint zDim = outArrs[0]->sizeAt(axis);
|
||||
// general case
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
||||
Nd4jLong coords[MAX_RANK];
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
|
||||
shape::index2coords(i, input.getShapeInfo(), coords);
|
||||
const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
|
||||
|
||||
uint outArrIdx = 0;
|
||||
|
||||
while (coords[axis] >= zDim) {
|
||||
coords[axis] -= zDim;
|
||||
++outArrIdx;
|
||||
}
|
||||
|
||||
T* z = outArrs[outArrIdx]->bufferAsT<T>();
|
||||
const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
|
||||
z[zOffset] = xBuff[xOffset];
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, input.lengthOf());
|
||||
}
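The while-loop in the general case above converts an input coordinate along the split axis into an output index plus a local coordinate. Since the split op produces equally sized outputs, this is equivalent to a division and a modulo; a tiny standalone check of that equivalence:

#include <cassert>

int main() {
    const int zDim = 4;                                          // extent of every output along the axis
    for (int c = 0; c < 3 * zDim; ++c) {                         // input extent along the axis
        int local = c, outArrIdx = 0;
        while (local >= zDim) { local -= zDim; ++outArrIdx; }    // loop as written in splitCpuGeneric
        assert(outArrIdx == c / zDim && local == c % zDim);      // closed-form equivalent
    }
    return 0;
}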
|
||||
|
||||
|
||||
/**
|
||||
* This kernel accumulates X arrays, and stores result into Z
|
||||
*
|
||||
|
|
|
@ -50,8 +50,9 @@ namespace sd {
|
|||
    template <typename T>
    class ND4J_EXPORT SpecialMethods {
    public:
        static void concatCpuGeneric(const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
        static void concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis);
        static void concatCpuGeneric(int dimension, int numArrays, Nd4jPointer *data, Nd4jPointer *inputShapeInfo, void *result, Nd4jLong *resultShapeInfo);
        static void splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis);
        static void accumulateGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length);
        static void averageGeneric(void **x, void *z, Nd4jLong *zShapeInfo, int n, const Nd4jLong length, bool propagate);
@ -2853,289 +2853,6 @@ TEST_F(DeclarableOpsTests1, LRN1) {
|
|||
lrn.getOpName();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_1) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_2) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24};
|
||||
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_3) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_4) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_5) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_6) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24};
|
||||
Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_7) {
|
||||
|
||||
float buff1[] = {1};
|
||||
float expBuff[] = {1, 1, 1};
|
||||
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_8) {
|
||||
|
||||
float buff1[] = {1};
|
||||
float expBuff[] = {1, 1, 1};
|
||||
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_9) {
|
||||
|
||||
float buff1[] = {1};
|
||||
float expBuff[] = {1, 1, 1};
|
||||
Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Stack_10) {
|
||||
|
||||
float buff1[] = {1};
|
||||
float expBuff[] = {1, 1, 1};
|
||||
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
//expected.printShapeInfo("exp");
|
||||
//output->printShapeInfo("out");
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests1, Stack_11) {
|
||||
|
||||
float buff1[] = {1};
|
||||
float expBuff[] = {1, 1, 1};
|
||||
Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) {
|
||||
auto exp = NDArrayFactory::create<int>('c', {4});
|
||||
exp.linspace(1);
|
||||
|
@ -3330,74 +3047,6 @@ TEST_F(DeclarableOpsTests1, softmax_test8) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_1) {
|
||||
float inBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
float expBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
|
||||
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
|
||||
|
||||
sd::ops::stack op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_2) {
|
||||
float inBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
float expBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
|
||||
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 1, 3});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 1, 3});
|
||||
|
||||
sd::ops::stack op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {0});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Test_Stack_Edge_3) {
|
||||
float inBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
float expBuff[] = {1.0f, 2.0f, 3.0f};
|
||||
|
||||
auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
|
||||
|
||||
auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});
|
||||
|
||||
sd::ops::stack op;
|
||||
|
||||
auto result = op.evaluate({&input}, {}, {1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
//z->printShapeInfo();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, Reverse_1 ) {
|
||||
|
||||
|
|
|
@ -250,68 +250,6 @@ TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) {
|
|||
ASSERT_EQ(Status::OK(), result);
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {0});
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
|
||||
sd::ops::stack op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
sd::ops::reduce_min sumOp;
|
||||
auto res2 = sumOp.evaluate({&e}, {1.}, {1});
|
||||
ASSERT_EQ(res2->status(), Status::OK());
|
||||
auto out = res2->at(0);
|
||||
|
||||
ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());
|
||||
delete res2;
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_2) {
|
||||
auto x = NDArrayFactory::empty<float>();
|
||||
auto e = NDArrayFactory::create<float>('c', {0});
|
||||
|
||||
sd::ops::stack op;
|
||||
auto result = op.evaluate({&x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_3) {
|
||||
auto x = NDArrayFactory::empty<float>();
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||
|
||||
sd::ops::stack op;
|
||||
auto result = op.evaluate({&x, &x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_stack_4) {
|
||||
auto x = NDArrayFactory::create<float>('c', {0});
|
||||
auto e = NDArrayFactory::create<float>('c', {2, 0});
|
||||
|
||||
sd::ops::stack op;
|
||||
auto result = op.evaluate({&x, &x}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) {
|
||||
|
||||
auto e = NDArrayFactory::create<float>('c', {1, 0});
|
||||
|
@ -1655,30 +1593,489 @@ TEST_F(DeclarableOpsTests14, Test_broadcast_5D_4) {
|
|||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
// @Test
|
||||
// public void testMmulRank4_simple(){
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Stack_1) {
|
||||
|
||||
// INDArray arr1 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||
// INDArray arr2 = Nd4j.ones(DataType.FLOAT, 32, 12, 128, 64);
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
// DynamicCustomOp op = DynamicCustomOp.builder("matmul")
|
||||
// .addInputs(arr1, arr2)
|
||||
// .addIntegerArguments(0, 1) //Transpose arr2 only
|
||||
// .build();
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
// List<LongShapeDescriptor> shapes = op.calculateOutputShape();
|
||||
// assertEquals(1, shapes.size());
|
||||
// long[] shape = new long[]{32,12,128,128};
|
||||
// assertArrayEquals(shape, shapes.get(0).getShape());
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
// INDArray out = Nd4j.create(DataType.FLOAT, shape);
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
// op.setOutputArgument(0, out);
|
||||
// Nd4j.exec(op);
|
||||
// // System.out.println(out);
|
||||
delete results;
|
||||
|
||||
// INDArray exp = Nd4j.valueArrayOf(shape, 64.0, DataType.FLOAT); //Each entry in output is sum of 64 (1.0 x 1.0) multiplications
|
||||
// assertEquals(exp, out);
|
||||
// }
|
||||
}
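These tests construct arrays from raw shapeInfo buffers. As a reading aid (this is my understanding of the layout used here, not something stated in the change): the buffer holds the rank, the shape, the strides, and a trailing triple of extra/type flags, element-wise stride, and ordering character. Annotated for the first input of Stack_1:

// Reading of shape1 = {2, 3, 4, 4, 1, 0, 1, 99} (assumed layout):
//   2      -> rank
//   3, 4   -> shape (3 rows, 4 columns)
//   4, 1   -> strides for 'c' order
//   0      -> extra/ArrayOptions slot (data type is set separately via ArrayOptions::setDataType)
//   1      -> element-wise stride (ews)
//   99     -> ordering character: 99 == 'c', 102 == 'f'
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};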
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Stack_2) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24};
|
||||
Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Stack_3) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {0});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Stack_4) {
|
||||
|
||||
float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12};
|
||||
float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24};
|
||||
Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99};
|
||||
Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99};
|
||||
ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
|
||||
ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);
|
||||
|
||||
NDArray input1(buff1, shape1);
|
||||
NDArray input2(buff2, shape2);
|
||||
NDArray expected(expBuff, expShape);
|
||||
|
||||
sd::ops::stack op;
|
||||
auto results = op.evaluate({&input1, &input2}, {}, {1});
|
||||
auto output = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShapeStrict(*output));
|
||||
ASSERT_TRUE(expected.equalsTo(output));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Stack_5) {

    float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24};
    float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24};
    Nd4jLong shape1[] = {2, 12, 1, 1, 1, 0, 1, 99};
    Nd4jLong shape2[] = {2, 12, 1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray input2(buff2, shape2);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input2}, {}, {0});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
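// Stack_6: stacking two {12, 1} columns along axis 1 yields a {12, 2, 1} array whose
// c-order buffer interleaves the two inputs element by element, as expBuff shows.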
TEST_F(DeclarableOpsTests14, Stack_6) {

    float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    float buff2[] = {13, 14, 16, 16, 17, 18, 19, 20, 21, 22, 23, 24};
    float expBuff[] = {1, 13, 2, 14, 3, 16, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23, 12, 24};
    Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99};
    Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99};
    Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray input2(buff2, shape2);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input2}, {}, {1});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
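// Stack_7 - Stack_10 cover single-element inputs: {1,1} and {1} shaped arrays stacked
// along axis 0 or 1, with expected output shapes {3,1,1}, {3,1}, {1,3,1} and {1,3}.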
TEST_F(DeclarableOpsTests14, Stack_7) {

    float buff1[] = {1};
    float expBuff[] = {1, 1, 1};
    Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_8) {

    float buff1[] = {1};
    float expBuff[] = {1, 1, 1};
    Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input1, &input1}, {}, {0});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_9) {

    float buff1[] = {1};
    float expBuff[] = {1, 1, 1};
    Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_10) {

    float buff1[] = {1};
    float expBuff[] = {1, 1, 1};
    Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input1, &input1}, {}, {1});
    auto output = results->at(0);

    // expected.printShapeInfo("exp");
    // output->printShapeInfo("out");

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

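// Stack_11: no axis IArg is passed, so the op should fall back to the default axis 0;
// the expected {3, 1} shape matches Stack_8.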
TEST_F(DeclarableOpsTests14, Stack_11) {

    float buff1[] = {1};
    float expBuff[] = {1, 1, 1};
    Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99};
    Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99};
    ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32);
    ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32);

    NDArray input1(buff1, shape1);
    NDArray expected(expBuff, expShape);

    sd::ops::stack op;
    auto results = op.evaluate({&input1, &input1, &input1}, {}, {});
    auto output = results->at(0);

    ASSERT_TRUE(expected.isSameShapeStrict(*output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_12) {
    float inBuff[] = {1.0f, 2.0f, 3.0f};
    float expBuff[] = {1.0f, 2.0f, 3.0f};

    auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
    auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});

    sd::ops::stack op;
    auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_13) {
    float inBuff[] = {1.0f, 2.0f, 3.0f};
    float expBuff[] = {1.0f, 2.0f, 3.0f};

    auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 1, 3});
    auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 1, 3});

    sd::ops::stack op;
    auto result = op.evaluate({&input}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Stack_14) {
    float inBuff[] = {1.0f, 2.0f, 3.0f};
    float expBuff[] = {1.0f, 2.0f, 3.0f};

    auto input = NDArrayFactory::create<float>(inBuff, 'c', {1, 3});
    auto exp = NDArrayFactory::create<float>(expBuff, 'c', {1, 1, 3});

    sd::ops::stack op;
    auto result = op.evaluate({&input}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    // z->printShapeInfo();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

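// Stack_15: negative axes are accepted; for rank-3 inputs axis -4 wraps to axis 0 of the
// rank-4 result, giving shape {3, 2, 3, 5}.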
TEST_F(DeclarableOpsTests14, Stack_15) {
    auto t = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto u = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto v = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto exp = NDArrayFactory::create<double>('c', {3, 2, 3, 5});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v}, {}, {-4});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));

    delete result;
}

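// Stack_16: stacking three rank-0 scalars along axis 0 produces a plain length-3 vector.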
TEST_F(DeclarableOpsTests14, Stack_16) {
    auto t = NDArrayFactory::create<float>(1.0f);
    auto u = NDArrayFactory::create<float>(2.0f);
    auto v = NDArrayFactory::create<float>(3.0f);
    auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(DeclarableOpsTests14, Stack_17) {
    auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
    auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});
    auto v = NDArrayFactory::create<float>('c', {1, 1}, {3.0f});
    auto w = NDArrayFactory::create<float>('c', {1, 1}, {4.0f});
    auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    // z->printShapeInfo("z shape");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

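// Stack_18 - Stack_20 exercise empty inputs: stacking a {0} array gives a {1, 0} result, and
// stacking empty arrays gives {0} and {2, 0} outputs. Stack_18 additionally checks that
// reduce_min over the zero-length dimension returns DataTypeUtils::infOrMax (the identity of min).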
TEST_F(DeclarableOpsTests14, Stack_18) {
    auto x = NDArrayFactory::create<float>('c', {0});
    auto e = NDArrayFactory::create<float>('c', {1, 0});

    sd::ops::stack op;
    auto result = op.evaluate({&x}, {}, {0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    ASSERT_EQ(e, *z);

    sd::ops::reduce_min sumOp;
    auto res2 = sumOp.evaluate({&e}, {1.}, {1});
    ASSERT_EQ(res2->status(), Status::OK());
    auto out = res2->at(0);

    ASSERT_EQ(out->e<float>(0), DataTypeUtils::infOrMax<float>());

    delete res2;
    delete result;
}

TEST_F(DeclarableOpsTests14, Stack_19) {
    auto x = NDArrayFactory::empty<float>();
    auto e = NDArrayFactory::create<float>('c', {0});

    sd::ops::stack op;
    auto result = op.evaluate({&x}, {}, {0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    ASSERT_EQ(e, *z);

    delete result;
}

TEST_F(DeclarableOpsTests14, Stack_20) {
    auto x = NDArrayFactory::empty<float>();
    auto e = NDArrayFactory::create<float>('c', {2, 0});

    sd::ops::stack op;
    auto result = op.evaluate({&x, &x}, {}, {0});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);
    ASSERT_EQ(e, *z);

    delete result;
}

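// Stack_21: stacking two {3, 2} arrays along axis 0 must agree with concatenating them along
// axis 0 and reshaping the result to {2, 3, 2}.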
TEST_F(DeclarableOpsTests14, Stack_21) {

    NDArray x1('c', {3, 2}, sd::DataType::FLOAT32);
    NDArray x2('c', {3, 2}, sd::DataType::FLOAT32);
    x1.linspace(0);
    x2.linspace(6);

    sd::ops::stack opStack;
    auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, resultStack->status());

    sd::ops::concat opConcat;
    auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, resultConcat->status());

    auto outStack = resultStack->at(0);
    auto outConcat = resultConcat->at(0);

    outConcat->reshapei({2, 3, 2});

    ASSERT_TRUE(outStack->isSameShape(outConcat));
    ASSERT_TRUE(outStack->equalsTo(outConcat));

    delete resultStack;
    delete resultConcat;
}

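// The hunk below removes the old DeclarableOpsTests4::Test_Stack_4; the same case is covered
// by DeclarableOpsTests14::Stack_15 above.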
@@ -927,24 +927,6 @@ TEST_F(DeclarableOpsTests4, Test_Split_3) {
    delete result;
}

TEST_F(DeclarableOpsTests4, Test_Stack_4) {
    auto t = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto u = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto v = NDArrayFactory::create<double>('c', {2, 3, 5});
    auto exp = NDArrayFactory::create<double>('c', {3, 2, 3, 5});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v}, {}, {-4});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) {
    auto x = NDArrayFactory::create<double>('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4});
    auto exp = NDArrayFactory::create<double>('c', {2, 1, 2}, {1, 2, 3, 4});

@@ -995,22 +977,6 @@ TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) {
    delete result;
}

TEST_F(DeclarableOpsTests4, Test_1D_1) {
    auto x = NDArrayFactory::create<double>('c', {2, 3});

    sd::ops::unstack op;
    auto result = op.evaluate({&x}, {}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    ASSERT_EQ(3, result->size());

    for (int e = 0; e < 3; e++)
        ASSERT_EQ(1, result->at(e)->rankOf());

    delete result;
}

TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    auto exp = NDArrayFactory::create<double>('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

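// In the hunks that follow, each tensorAlongDimension(index, {dims}) call is shown together with
// its replacement, the NDArray operator() sub-array form. Note the inverted convention:
// tensorAlongDimension listed the dimensions kept in the slice, while operator()(index, {dims})
// lists the dimensions being indexed over (their complement).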
@@ -637,7 +637,7 @@ TEST_F(DeclarableOpsTests9, concat_test18) {
    for (int e = 0; e < 2000; e++) {
        auto exp = NDArrayFactory::create<int>('c', {300});
        exp.assign(e);
        auto row = z.tensorAlongDimension(e, {1});
        auto row = z(e, {0});
        ASSERT_EQ(exp, row);
    }
}

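// concat_test26 (added below) uses column-major ('f') inputs and expectation; the expected
// buffer lists the concatenated values 0..17 in 'f' (column-major) order, which is why it looks
// shuffled relative to the logical row-major sequence.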
@@ -778,6 +778,33 @@ TEST_F(DeclarableOpsTests9, concat_test25) {
    delete result;
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test26) {

    NDArray x0('f', {1, 2, 3}, sd::DataType::INT32);
    NDArray x1('f', {1, 2, 3}, sd::DataType::INT32);
    NDArray x2('f', {1, 2, 3}, sd::DataType::INT32);

    NDArray exp('f', {3, 2, 3}, {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, sd::DataType::INT32);

    x0.linspace(0);
    x1.linspace(6);
    x2.linspace(12);

    sd::ops::concat op;

    auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto output = result->at(0);
    output->printLinearBuffer();

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

    delete result;
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {

@@ -206,24 +206,17 @@ TEST_F(EmptyTests, test_empty_scatter_1) {
}

TEST_F(EmptyTests, test_empty_scatter_2) {
    auto x = NDArrayFactory::create<float>('c', {5});
    auto z = NDArrayFactory::create<float>('c', {5});
    NDArray x('c', {5}, sd::DataType::FLOAT32);
    NDArray z('c', {5}, sd::DataType::FLOAT32);
    auto indices = NDArrayFactory::create<int>('c', {0});
    auto updates = NDArrayFactory::create<float>('c', {0});

    x.linspace(1.0f);

    Context ctx(1);
    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
    ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo());
    ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo());
    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
    bool args[] = {true};
    ctx.setBArguments(args, 1);

    sd::ops::scatter_upd op;
    auto result = op.execute(&ctx);
    ASSERT_EQ(Status::OK(), result);
    auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true});

    ASSERT_EQ(Status::OK(), status);

    ASSERT_EQ(x, z);
}

@@ -1242,7 +1242,7 @@ TEST_F(JavaInteropTests, test_ismax_view) {
    v.assign(1.0);

    auto e = v.like();
    auto t = e.tensorAlongDimension(0, {0, 1});
    auto t = e(0, {2});
    t.assign(1.0);

    auto z = v.ulike();

@@ -674,10 +674,10 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
    auto e = NDArrayFactory::create<bool>('c', {3, 4});
    e.assign(false);

    auto row = y.tensorAlongDimension(1, {1});
    auto row = y(1, {0});
    row.assign(2.0f);

    auto erow = e.tensorAlongDimension(1, {1});
    auto erow = e(1, {0});
    erow.assign(true);

    auto tadPackY = sd::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), 1);

@@ -197,7 +197,7 @@ TEST_F(NDArrayTest, EqualityTest1) {
TEST_F(NDArrayTest, TestTad1) {
    auto array = NDArrayFactory::create_<float>('c', {3, 3});

    auto row2 = array->tensorAlongDimension(1, {1});
    auto row2 = (*array)(1, {0});

    ASSERT_TRUE(row2.isView());
    ASSERT_EQ(3, row2.lengthOf());

@@ -221,7 +221,7 @@ TEST_F(NDArrayTest, TestTad2) {
TEST_F(NDArrayTest, TestTad3) {
    auto array = NDArrayFactory::create_<float>('c', {4, 3});

    auto row2 = array->tensorAlongDimension(1, {1});
    auto row2 = (*array)(1, {0});

    ASSERT_TRUE(row2.isView());
    ASSERT_EQ(3, row2.lengthOf());

@@ -1529,7 +1529,7 @@ TEST_F(NDArrayTest, TestStdDev1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(NDArrayTest, TestStdDev2) {
    auto array = NDArrayFactory::create<double>('c', {5, 6});
    auto tad = array.tensorAlongDimension(0, {0});
    auto tad = array(0, {1});

    ASSERT_EQ(5, tad.lengthOf());

@@ -229,7 +229,7 @@ TEST_F(NlpTests, basic_sg_ns_test_1) {
    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());

    auto row0 = syn0({1,2, 0,0}, true);

    ASSERT_EQ(exp0, row0);
    ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6));

@@ -418,10 +418,6 @@ TEST_F(NlpTests, test_sg_ns_batch_1) {
    auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true);
    ASSERT_EQ(Status::OK(), result->status());

    auto row0 = syn0({0,0, 0,0}, true);
    auto row1 = syn0({5,0, 0,0}, true);
    auto row2 = syn0({2,0, 0,0}, true);

    delete result;
}

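// TestUnstack13 (added below) mirrors the Test_1D_1 case removed from DeclarableOpsTests4:
// unstacking a {2, 3} array along axis 1 should yield three rank-1 outputs.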
@@ -328,6 +328,25 @@ TEST_F(ParityOpsTests, TestUnstack12) {
    delete result;
}

TEST_F(ParityOpsTests, TestUnstack13) {

    auto x = NDArrayFactory::create<double>('c', {2, 3});

    sd::ops::unstack op;
    auto result = op.evaluate({&x}, {}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    ASSERT_EQ(3, result->size());

    for (int e = 0; e < 3; e++)
        ASSERT_EQ(1, result->at(e)->rankOf());

    delete result;
}


TEST_F(ParityOpsTests, ExpandDimsTest1) {
    auto input = NDArrayFactory::create<float>('c', {5, 5});
    input.linspace(1);

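// The ScalarTests Test_Stack_1 and Test_Stack_2 removed below are duplicated by
// DeclarableOpsTests14::Stack_16 and Stack_17 above.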
@@ -216,47 +216,6 @@ TEST_F(ScalarTests, Test_Permute_1) {
    delete result;
}


TEST_F(ScalarTests, Test_Stack_1) {
    auto t = NDArrayFactory::create<float>(1.0f);
    auto u = NDArrayFactory::create<float>(2.0f);
    auto v = NDArrayFactory::create<float>(3.0f);
    auto exp = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}

TEST_F(ScalarTests, Test_Stack_2) {
    auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
    auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});
    auto v = NDArrayFactory::create<float>('c', {1, 1}, {3.0f});
    auto w = NDArrayFactory::create<float>('c', {1, 1}, {4.0f});
    auto exp = NDArrayFactory::create<float>('c', {4, 1, 1}, {1, 2, 3, 4});

    sd::ops::stack op;
    auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    // z->printShapeInfo("z shape");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}


TEST_F(ScalarTests, Test_Concat_Scalar_1) {
    auto t = NDArrayFactory::create<float>('c', {1, 1}, {1.0f});
    auto u = NDArrayFactory::create<float>('c', {1, 1}, {2.0f});

@@ -268,7 +227,7 @@ TEST_F(ScalarTests, Test_Concat_Scalar_1) {
    auto result = op.evaluate({&t, &u, &v, &w}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

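// In the TAD assertion below, only the stride stored for the trailing unity dimension of the
// 'f'-order expectation changes (16 -> 1); the shape {4, 4, 1} and the remaining fields stay the same.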
@@ -196,7 +196,7 @@ public:
    int dimensionLength = 2;
    int dimension[2] = {2, 3};
    Nd4jLong tadAssertionC[10] = {3, 4, 4, 1, 4, 1, 16, 16384, 1, 99};
    Nd4jLong tadCAssertionF[10] = {3, 4, 4, 1, 1, 4, 16, 16384, 1, 102};
    Nd4jLong tadCAssertionF[10] = {3, 4, 4, 1, 1, 4, 1, 16384, 1, 102};
};

@@ -130,7 +130,7 @@ TEST_F(TadTests, TadEdgeCase_1) {
    auto exp = NDArrayFactory::create<float>('c', {5, 4});
    array.linspace(1);

    auto tad = array.tensorAlongDimension(0, {0, 1});
    auto tad = array(0, {2});

    ASSERT_TRUE(exp.isSameShape(tad));
}

@@ -140,7 +140,7 @@ TEST_F(TadTests, TestEdgeCase_2) {
    auto array = NDArrayFactory::create<float>('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6});

    for (int e = 0; e < array.lengthOf(); e++) {
        auto tad = array.tensorAlongDimension(e, {2});
        auto tad = array(e, {0, 1});
        ASSERT_NEAR(tad.e<float>(0), array.e<float>(e), 1e-5);
    }
}

@@ -148,7 +148,7 @@ TEST_F(TadTests, TestEdgeCase_2) {
TEST_F(TadTests, TadEdgeCase_2) {
    auto array = NDArrayFactory::create<float>('c', {2, 3, 4});

    auto tad = array.tensorAlongDimension(0, {1});
    auto tad = array(0, {0, 2});

    ASSERT_EQ(3, tad.lengthOf());
}