Oleh tenzor mmul (#231)
* Libnd4j: TensorMMul backprop op #8174, raw implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 merge master and some corrections Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 algorithm update, need testing, sync with master * Libnd4j: TensorMMul backprop op #8174 fixed incorrect B axes calculation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 optimize axes identification and fix bug of indeces overlapping, added first test. need testing with different shapes Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 some fixes and improvements need more testing Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 fixed order of matrix multiply Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 fixed issue of incorrect axes definition, add tests based on TF, need additional testing for case dLdC not equal 1 Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 fixed scalar case add test Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 fixed bp algorithm, axes definition, need some mode testing with different orders combination f,c; c,f f,f and add some checks for inputs Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 some checks and corrections added tests, exists the problem with different input orders support A-f B-c and A-f B-f Signed-off-by: Oleg <oleg.semeniv@gmail.com> * Libnd4j: TensorMMul backprop op #8174 sync master Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b,permutForC) Signed-off-by: Yurii <iuriish@yahoo.com> * Libnd4j: TensorMMul backprop op #8174 code clean up and refactoring Signed-off-by: Oleg <oleg.semeniv@gmail.com> * - add check for linspase ordered permutations in ShapeUtils::evalShapeForTensorDot Signed-off-by: Yurii <iuriish@yahoo.com> * - provide additional code in shape::reshape stuff in order to reduce amount of allocation/copy operations during reshaping procedure Signed-off-by: Yurii <iuriish@yahoo.com> * - further work on problem of wrong shape evaluation during permute/reshape procedures Signed-off-by: Yurii <iuriish@yahoo.com> * - still looking for bug reason in reshape/permute stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - correct bug in transform cuda native ops Signed-off-by: Yurii <iuriish@yahoo.com> * - correct bug in NDArray::assign Signed-off-by: Yurii <iuriish@yahoo.com> * - remove old shape::reshape stuff Signed-off-by: Yurii <iuriish@yahoo.com> * - add possibility to disable copy of old buffer to new buffer during reshape operation in NDArray class Signed-off-by: Yurii <iuriish@yahoo.com> * - correct bug in tensorDot which had to do with wrong pointers assigments Signed-off-by: Yurii <iuriish@yahoo.com> Co-authored-by: Oleh <oleg.semeniv@gmail.com>master
parent
8c0e378ec3
commit
fe47f52896
|
@ -999,14 +999,14 @@ namespace nd4j {
|
|||
* set new order and shape in case of suitable array length (in-place operation)
|
||||
* order - order to set
|
||||
* shape - shape to set
|
||||
*
|
||||
* copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping
|
||||
* if there was permute applied before or there are weird strides, then new buffer is allocated for array
|
||||
*/
|
||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
|
||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
|
||||
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
|
||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
|
||||
bool reshapei(const std::vector<Nd4jLong>& shape);
|
||||
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
|
||||
|
||||
/**
|
||||
* creates new array with corresponding order and shape, new array will point on _buffer of this array
|
||||
|
@ -1015,8 +1015,8 @@ namespace nd4j {
|
|||
*
|
||||
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
|
||||
*/
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
|
||||
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
|
||||
|
||||
/**
|
||||
* calculate strides and set given order
|
||||
|
@ -1493,7 +1493,7 @@ namespace nd4j {
|
|||
* @return
|
||||
*/
|
||||
bool isS() const;
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> asVectorT();
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const;
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
// copy constructor
|
||||
NDArray::NDArray(const NDArray& other) {
|
||||
|
||||
|
||||
_context = other._context;
|
||||
_offset = 0;
|
||||
|
||||
|
@ -308,7 +308,7 @@ NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::La
|
|||
if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) {
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
|
||||
}
|
||||
|
||||
|
||||
// one word that is why used 1
|
||||
Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
|
||||
|
||||
|
@ -435,11 +435,11 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
|
|||
_offset = 0;
|
||||
|
||||
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
|
||||
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), &offsets[0], 2 * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
if (dtype == DataType::UTF8) {
|
||||
memcpy(data, str.data(), str.size());
|
||||
}
|
||||
|
@ -456,13 +456,13 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
|
|||
/////////////////////////////////////////////////////////////////////////
|
||||
// constructors for vector of strings
|
||||
NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) {
|
||||
|
||||
|
||||
if (!DataTypeUtils::isS(dataType))
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used");
|
||||
|
||||
if (shape::prodLong(shape.data(), shape.size()) != string.size())
|
||||
throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array");
|
||||
|
||||
|
||||
for (const auto& str : string) {
|
||||
if (!unicode::isStringValidU8(str, str + std::char_traits<char>::length(str)) ) {
|
||||
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
|
||||
|
@ -497,7 +497,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
setAttached(context->getWorkspace() != nullptr);
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
|
@ -631,9 +631,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
|
|||
setAttached(context->getWorkspace() != nullptr);
|
||||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
auto cdata = data + offsets[e];
|
||||
|
@ -699,7 +699,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
auto cdata = data + offsets[e];
|
||||
|
@ -715,7 +715,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
|
||||
|
@ -764,8 +764,8 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
|||
|
||||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
auto cdata = data + offsets[e];
|
||||
|
@ -781,7 +781,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
|
||||
|
@ -831,7 +831,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
||||
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e += increment) {
|
||||
auto cdata = data + offsets[e];
|
||||
|
@ -847,7 +847,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
|
|||
}
|
||||
};
|
||||
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
|
||||
|
||||
|
||||
tickWriteHost();
|
||||
syncToDevice();
|
||||
}
|
||||
|
@ -887,8 +887,8 @@ bool NDArray::isC() const {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::isS() const {
|
||||
return (dataType() == DataType::UTF8 ||
|
||||
dataType() == DataType::UTF16 ||
|
||||
return (dataType() == DataType::UTF8 ||
|
||||
dataType() == DataType::UTF16 ||
|
||||
dataType() == DataType::UTF32);
|
||||
}
|
||||
|
||||
|
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
|
|||
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
|
||||
}
|
||||
|
||||
// memcpy is allowed only for same order && same ews (being equal to 1)
|
||||
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||
// memcpy is allowed only for same order c && same ews (being equal to 1)
|
||||
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
|
||||
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
|
||||
else {
|
||||
NDArray::prepareSpecialUse({this}, {&other});
|
||||
|
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::printShapeInfo(const char * msg) const {
|
||||
//shape::printShapeInfo(_shapeInfo);
|
||||
if (msg == nullptr)
|
||||
shape::printShapeInfoLinear(_shapeInfo);
|
||||
else {
|
||||
int rank = shape::rank(_shapeInfo);
|
||||
int lim = shape::shapeInfoLength(rank);
|
||||
printf("%s: [", msg);
|
||||
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
|
||||
printf("%lld", (long long) _shapeInfo[i]);
|
||||
if (i < lim - 1)
|
||||
printf(", ");
|
||||
}
|
||||
printf("]\n");
|
||||
|
||||
int rank = shape::rank(_shapeInfo);
|
||||
int lim = shape::shapeInfoLength(rank);
|
||||
|
||||
if(msg != nullptr)
|
||||
printf("shapeInfo %s: [", msg);
|
||||
else
|
||||
printf("shapeInfo: [");
|
||||
|
||||
printf("%i, ", rank);
|
||||
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
|
||||
if(i == rank + 1)
|
||||
printf(" ");
|
||||
printf("%lld,", _shapeInfo[i]);
|
||||
}
|
||||
printf(" %lld,", shape::type(_shapeInfo));
|
||||
printf("%lld,", shape::elementWiseStride(_shapeInfo));
|
||||
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
|
||||
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
|
@ -1624,7 +1629,7 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons
|
|||
if (e < limit - 1)
|
||||
printf(", ");
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (this->isS()) {
|
||||
// todo do we need this print offsets
|
||||
/*
|
||||
|
@ -1773,7 +1778,7 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const {
|
|||
printf("%s\n", this->e<bool>(0)?"true":"false");
|
||||
}
|
||||
else if (this->isS()) {
|
||||
// todo do we need this
|
||||
// todo do we need this
|
||||
// printf("\"%lld\"\n", this->getOffset(e));
|
||||
printf("\"%s\"\n", this->e<std::string>(0).c_str());
|
||||
}
|
||||
|
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// set new order and shape in case of suitable array length
|
||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
|
||||
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
std::vector<Nd4jLong> vShape(shape);
|
||||
return reshapei(order, vShape);
|
||||
return reshapei(order, vShape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
|
||||
return reshapei('c', shape);
|
||||
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
return reshapei(ordering(), shape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
|
||||
return reshapei('c', shape);
|
||||
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
|
||||
return reshapei(ordering(), shape, copyToNewBuff);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
|
||||
|
||||
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
|
||||
newArr.reshapei(order, shape);
|
||||
newArr.reshapei(order, shape, copyToNewBuff);
|
||||
|
||||
return newArr;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
|
||||
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
|
||||
|
||||
this->reshapei(order, shape);
|
||||
this->reshapei(order, shape, copyToNewBuff);
|
||||
return std::move(*this);
|
||||
}
|
||||
|
||||
|
@ -2280,7 +2285,7 @@ template <typename T>
|
|||
NDArray NDArray::asT() const{
|
||||
|
||||
auto result = isScalar() ? NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
|
||||
|
||||
|
||||
NDArray::prepareSpecialUse({&result}, {this});
|
||||
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
|
||||
NDArray::registerSpecialUse({&result}, {this});
|
||||
|
@ -2298,15 +2303,15 @@ NDArray NDArray::asS() const {
|
|||
|
||||
auto dtype = DataTypeUtils::fromT<T>();
|
||||
|
||||
if (!(DataTypeUtils::isS(dtype)))
|
||||
if (!(DataTypeUtils::isS(dtype)))
|
||||
throw std::invalid_argument("NDArray::asS: invalid DataType used");
|
||||
|
||||
|
||||
if (dtype == dataType()) {
|
||||
|
||||
|
||||
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
|
||||
const auto nInputoffsets = bufferAsT<Nd4jLong>();
|
||||
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true);
|
||||
|
||||
|
||||
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
|
||||
res.setAttached(getContext()->getWorkspace() != nullptr);
|
||||
|
||||
|
@ -2319,7 +2324,7 @@ NDArray NDArray::asS() const {
|
|||
registerPrimaryUse({ &res }, { this });
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
|
||||
|
||||
std::vector<Nd4jLong> offsets(lengthOf() + 1);
|
||||
|
@ -2353,7 +2358,7 @@ NDArray NDArray::asS() const {
|
|||
|
||||
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
|
||||
res.setAttached(getContext()->getWorkspace() != nullptr);
|
||||
|
||||
|
||||
preparePrimaryUse({ &res }, { this });
|
||||
|
||||
memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
|
||||
|
@ -2403,7 +2408,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND
|
|||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::asT(DataType dtype) const {
|
||||
|
||||
|
||||
if (isS() && !DataTypeUtils::isS(dtype))
|
||||
throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!");
|
||||
|
||||
|
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// set new order and shape in case of suitable array length
|
||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
||||
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
|
||||
|
||||
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
|
||||
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
|
||||
|
@ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
|
|||
Nd4jLong *shapeInfoNew;
|
||||
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
|
||||
|
||||
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
|
||||
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
|
||||
|
||||
// we can do this only if there was no permute applied, or there are no weird strides
|
||||
if (canReshape) {
|
||||
if(ordering() == 'c' && order == 'f')
|
||||
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
|
||||
|
||||
shape::setEws(shapeInfoNew, arrLength);
|
||||
setShapeInfo(shapeInfoNew);
|
||||
}
|
||||
else {
|
||||
NDArray temp(order, shape, dataType(), getContext());
|
||||
this->applyTransform(transform::Assign, temp, nullptr);
|
||||
if(copyToNewBuff)
|
||||
this->applyTransform(transform::Assign, temp, nullptr);
|
||||
*this = std::move(temp);
|
||||
}
|
||||
|
||||
|
@ -3463,7 +3464,7 @@ NDArray NDArray::dup(const char newOrder) const {
|
|||
if (isS()) {
|
||||
if (dataType() == DataType::UTF8) {
|
||||
std::vector<std::string> strings(lengthOf());
|
||||
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto i = start; i < stop; i += increment) {
|
||||
strings[i] = std::move(this->e<std::string>(i));
|
||||
|
@ -3521,7 +3522,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
|||
|
||||
if (isS()) {
|
||||
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||
|
||||
|
||||
if (dataType() == DataType::UTF8) {
|
||||
for (int e = 0; e < this->lengthOf(); e++) {
|
||||
auto s1 = this->e<std::string>(e);
|
||||
|
@ -3585,7 +3586,7 @@ std::string NDArray::e(const Nd4jLong i) const {
|
|||
if (i == lengthOf())
|
||||
throw std::runtime_error("Can't get std::string for index out of range");
|
||||
|
||||
|
||||
|
||||
if (this->dataType() == DataType::UTF16) {
|
||||
auto u16 = this->e<std::u16string>(i);
|
||||
std::string s;
|
||||
|
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
|||
auto shapeOf = shape::shapeOf(newShapeInfo);
|
||||
auto stridesOf = shape::stride(newShapeInfo);
|
||||
|
||||
Nd4jLong offset(0), subArrLen(1);
|
||||
Nd4jLong offset = 0;
|
||||
int n(isStrided ? 3 : 2), first, last, stride;
|
||||
|
||||
for (int d = rank - 1; d >= 0; --d) {
|
||||
|
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
|
|||
if(shapeOf[d] != 1)
|
||||
stridesOf[d] *= stride;
|
||||
}
|
||||
}
|
||||
|
||||
subArrLen *= shapeOf[d];
|
||||
Nd4jLong *shapeInfoNoUnities = newShapeInfo;
|
||||
|
||||
if(!keepUnitiesInShape) {
|
||||
|
||||
std::vector<int> dimsWithUnities;
|
||||
|
||||
for (uint d = 0; d < rank; ++d)
|
||||
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
|
||||
dimsWithUnities.push_back(d);
|
||||
|
||||
if(!dimsWithUnities.empty())
|
||||
shapeInfoNoUnities = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
|
||||
}
|
||||
|
||||
// check if there is possibility to set ews = 1
|
||||
shape::setEws(newShapeInfo, subArrLen);
|
||||
shape::checkStridesSetEwsAndOrder(shapeInfoNoUnities);
|
||||
|
||||
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
|
||||
NDArray result(_buffer, ShapeDescriptor(shapeInfoNoUnities), getContext(), offset + getBufferOffset());
|
||||
result._isView = true;
|
||||
|
||||
if(!keepUnitiesInShape) {
|
||||
const int coeff = isStrided ? 3 : 2;
|
||||
std::vector<Nd4jLong> nonUnitDims;
|
||||
|
||||
for (int d = 0; d < rank; ++d)
|
||||
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
|
||||
nonUnitDims.push_back(newShapeInfo[d+1]);
|
||||
|
||||
if(nonUnitDims.size() != rank)
|
||||
result.reshapei(nonUnitDims);
|
||||
}
|
||||
|
||||
RELEASE(newShapeInfo, getContext()->getWorkspace());
|
||||
if(newShapeInfo != shapeInfoNoUnities)
|
||||
RELEASE(shapeInfoNoUnities, getContext()->getWorkspace());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -30,15 +30,15 @@
|
|||
|
||||
namespace nd4j {
|
||||
class ND4J_EXPORT ShapeBuilders {
|
||||
public:
|
||||
public:
|
||||
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
|
||||
static Nd4jLong* createVectorShapeInfo(const nd4j::DataType dataType, const Nd4jLong length, nd4j::memory::Workspace* workspace = nullptr);
|
||||
|
||||
/**
|
||||
* create shapeInfo for given order basing on shape stored in shapeOnly vector
|
||||
* memory allocation for shapeInfo is on given workspace
|
||||
*/
|
||||
*/
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::initializer_list<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
|
||||
|
@ -51,6 +51,13 @@ namespace nd4j {
|
|||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
|
||||
|
||||
/**
|
||||
* allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
|
||||
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
|
||||
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
|
||||
*/
|
||||
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
|
||||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
|
||||
|
||||
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
|||
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
|
||||
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
|
||||
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
|
||||
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them)
|
||||
auto oPtr = new Nd4jLong[numOfSubArrs];
|
||||
|
||||
if (numOfSubArrs > 0)
|
||||
|
|
|
@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
|
|||
|
||||
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
NDArray aPR = a->permute(permutAt);
|
||||
NDArray bPR = b->permute(permutBt);
|
||||
// check whether permutation is necessary
|
||||
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||
|
||||
// check whether reshape is necessary
|
||||
if(!aPR.isSameShape(shapeAt))
|
||||
aPR.reshapei( shapeAt);
|
||||
if(!bPR.isSameShape(shapeBt))
|
||||
bPR.reshapei( shapeBt);
|
||||
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||
|
||||
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
|
||||
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
|
||||
|
||||
c->reshapei(outShape);
|
||||
|
||||
if(aP != aPR)
|
||||
delete aPR;
|
||||
if(bP != bPR)
|
||||
delete bPR;
|
||||
if(a != aP)
|
||||
delete aP;
|
||||
if(b != bP)
|
||||
delete bP;
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
|
||||
|
||||
|
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
|
|||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
NDArray *cP(c), *cPR(c);
|
||||
|
||||
// check whether permutation is required
|
||||
if(!permutForC.empty())
|
||||
cP = new NDArray(c->permute(permutForC));
|
||||
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
|
||||
|
||||
auto aPR = a->permute(permutAt);
|
||||
auto bPR = b->permute(permutBt);
|
||||
// check whether permutation is necessary
|
||||
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
|
||||
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
|
||||
|
||||
// check whether reshape is necessary
|
||||
if(!aPR.isSameShape(shapeAt))
|
||||
aPR.reshapei(shapeAt);
|
||||
if(!bPR.isSameShape(shapeBt))
|
||||
bPR.reshapei(shapeBt);
|
||||
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
|
||||
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
|
||||
|
||||
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
|
||||
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
|
||||
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
|
||||
|
||||
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
|
||||
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
|
||||
|
||||
mmul(aPR, bPR, cPR, 1.0, 0.0);
|
||||
|
||||
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
|
||||
cP->assign(cPR);
|
||||
|
||||
if(cPR != c)
|
||||
if(aP != aPR)
|
||||
delete aPR;
|
||||
if(bP != bPR)
|
||||
delete bPR;
|
||||
if(a != aP)
|
||||
delete aP;
|
||||
if(b != bP)
|
||||
delete bP;
|
||||
|
||||
if(cP != cPR)
|
||||
delete cPR;
|
||||
if(cP != c)
|
||||
if(c != cP)
|
||||
delete cP;
|
||||
}
|
||||
|
||||
|
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
|
|||
if(!whatToDoWithC.empty()) {
|
||||
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
|
||||
for(int i = 0; i < cArrs.size()-1; ++i)
|
||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
|
||||
}
|
||||
|
||||
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
|
||||
|
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
|
|||
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
|
||||
if(isAVector && bRank == 2) {
|
||||
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
|
||||
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
|
||||
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
|
||||
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
|
||||
delete A2;
|
||||
delete C2;
|
||||
|
|
|
@ -139,5 +139,15 @@ namespace nd4j {
|
|||
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
|
||||
|
||||
Nd4jLong *outShapeInfo = nullptr;
|
||||
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
|
||||
|
||||
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
|
||||
|
||||
return outShapeInfo;
|
||||
}
|
||||
|
||||
}
|
|
@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
|||
permutBt = axesB;
|
||||
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
|
||||
|
||||
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
|
||||
uint i1, i2;
|
||||
for(i1 = 0; i1 < aRank; ++i1)
|
||||
if(permutAt[i1] != i1)
|
||||
break;
|
||||
if(i1 == aRank)
|
||||
permutAt = {};
|
||||
for(i2 = 0; i2 < bRank; ++i2)
|
||||
if(permutBt[i2] != i2)
|
||||
break;
|
||||
if(i2 == bRank)
|
||||
permutBt = {};
|
||||
|
||||
Nd4jLong n2 = 1;
|
||||
for (int i = 0; i < axeAsize; i++)
|
||||
n2 *= aShapeInfo[axesA[i] + 1];
|
||||
shapeAt = {-1, n2};
|
||||
shapeAt = {shape::length(aShapeInfo) / n2, n2};
|
||||
|
||||
std::vector<Nd4jLong> oldShapeA;
|
||||
oldShapeA.resize(list_A.size());
|
||||
|
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
|
|||
Nd4jLong n3 = 1;
|
||||
for (int i = 0; i < axeBsize; i++)
|
||||
n3 *= bShapeInfo[axesB[i] + 1];
|
||||
shapeBt = {n3, -1};
|
||||
shapeBt = {n3, shape::length(bShapeInfo) / n3};
|
||||
|
||||
std::vector<Nd4jLong> oldShapeB;
|
||||
oldShapeB.resize(list_B.size());
|
||||
|
@ -306,10 +319,10 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
|
|||
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
|
||||
|
||||
if (!arr.nonNull())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
|
||||
|
||||
if (rank != arr.rankOf())
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
|
||||
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
|
||||
|
||||
auto shapeInfoLength = shape::shapeInfoLength(rank);
|
||||
// allocate memory for new array - shapeInfo
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -84,7 +84,7 @@ namespace functions {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||
|
||||
for (int i = tid; i < length; i += totalThreads)
|
||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||
|
||||
for (int i = tid; i < length; i += totalThreads)
|
||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||
|
|
|
@ -97,7 +97,7 @@ namespace functions {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||
|
||||
for (Nd4jLong i = tid; i < length; i += totalThreads)
|
||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||
|
|
|
@ -87,7 +87,7 @@ namespace functions {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||
|
||||
for (int i = tid; i < length; i += totalThreads)
|
||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||
|
|
|
@ -89,7 +89,7 @@ namespace functions {
|
|||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
|
||||
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
|
||||
|
||||
for (int i = tid; i < length; i += totalThreads)
|
||||
z[i * zEws] = OpType::op(x[i * xEws], params);
|
||||
|
|
|
@ -21,70 +21,174 @@
|
|||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_tensormmul)
|
||||
|
||||
#include <numeric>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <MmulHelper.h>
|
||||
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
namespace ops {
|
||||
|
||||
auto c = OUTPUT_VARIABLE(0); //
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
|
||||
|
||||
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
||||
auto a = INPUT_VARIABLE(0);
|
||||
auto b = INPUT_VARIABLE(1);
|
||||
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
auto c = OUTPUT_VARIABLE(0);
|
||||
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
|
||||
|
||||
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int)INT_ARG(e + 1);
|
||||
|
||||
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(tensordot, tensormmul);
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
|
||||
|
||||
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
|
||||
|
||||
DECLARE_SHAPE_FN(tensormmul) {
|
||||
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
|
||||
return Status::OK();
|
||||
}
|
||||
DECLARE_SYN(tensordot, tensormmul);
|
||||
|
||||
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(tensormmul) {
|
||||
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
|
||||
|
||||
// evaluate shapes
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
// building axes
|
||||
int axe0_size = INT_ARG(0);
|
||||
int axe1_size = INT_ARG(axe0_size+1);
|
||||
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
|
||||
for (int e = 0; e < axe0_size; e++)
|
||||
axes_0[e] = (int) INT_ARG(e+1);
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
||||
}
|
||||
for (int e = 0; e < axe1_size; e++)
|
||||
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
|
||||
|
||||
DECLARE_TYPES(tensormmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
||||
}
|
||||
// evaluate shapes
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(tensormmul) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
|
||||
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
|
||||
|
||||
auto A = INPUT_VARIABLE(0);
|
||||
auto B = INPUT_VARIABLE(1);
|
||||
|
||||
auto dLdC = INPUT_VARIABLE(2);
|
||||
|
||||
auto dLdA = OUTPUT_VARIABLE(0);
|
||||
auto dLdB = OUTPUT_VARIABLE(1);
|
||||
|
||||
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||
|
||||
int axe0Size = INT_ARG(0);
|
||||
int axe1Size = INT_ARG(axe0Size + 1);
|
||||
|
||||
auto Arank = A->rankOf();
|
||||
auto Brank = B->rankOf();
|
||||
auto dLdCrank = dLdC->rankOf();
|
||||
|
||||
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
|
||||
|
||||
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
|
||||
|
||||
// building axes
|
||||
std::vector<int> axes0(axe0Size), axes1(axe1Size);
|
||||
for (uint e = 0; e < axe0Size; e++)
|
||||
axes0[e] = (int)INT_ARG(e + 1);
|
||||
for (uint e = 0; e < axe1Size; e++)
|
||||
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
|
||||
|
||||
std::vector<int> permutAt, permutBt;
|
||||
std::vector<Nd4jLong> shapeAt, shapeBt;
|
||||
|
||||
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
|
||||
|
||||
// special case for scalar value
|
||||
if (dLdC->isScalar()) {
|
||||
|
||||
dLdA->assign((*dLdC) * *B);
|
||||
dLdB->assign((*dLdC) * *A);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
|
||||
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
|
||||
|
||||
// rank always have to be divided by 2
|
||||
std::vector<int> axesAdLdC, axesBdLdC;
|
||||
if (dLdCrank > 1) {
|
||||
axesAdLdC.resize(dLdCrank / 2);
|
||||
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
|
||||
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
|
||||
}
|
||||
else {
|
||||
axesAdLdC.push_back(0);
|
||||
axesBdLdC.push_back(0);
|
||||
}
|
||||
|
||||
// calculate dLdA
|
||||
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
|
||||
|
||||
// calculate dLdB
|
||||
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_SHAPE_FN(tensormmul_bp) {
|
||||
|
||||
auto aShapeInfo = inputShape->at(0);
|
||||
auto bShapeInfo = inputShape->at(1);
|
||||
auto dLShapeInfo = inputShape->at(2);
|
||||
|
||||
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
|
||||
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
|
||||
|
||||
Nd4jLong* dLdAShapeInfo = nullptr;
|
||||
Nd4jLong* dLdBShapeInfo = nullptr;
|
||||
|
||||
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
|
||||
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
|
||||
|
||||
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
DECLARE_TYPES(tensormmul_bp) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
|
||||
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
|
||||
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
|
|||
}
|
||||
|
||||
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
|
||||
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
|
||||
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
|
||||
nd4j::ops::conv2d conv2d;
|
||||
|
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
|
|||
}
|
||||
|
||||
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
|
||||
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
|
||||
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
|
||||
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
|
||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
|
||||
nd4j::ops::conv2d_bp conv2dBP;
|
||||
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||
|
|
|
@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
|
|||
//----- calculation of gradO -----//
|
||||
if(gradB) {
|
||||
if(gradB->rankOf() == 2)
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
|
||||
if(gradB != OUTPUT_VARIABLE(2))
|
||||
delete gradB;
|
||||
|
|
|
@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
|
|||
// ----- calculation of gradB ----- //
|
||||
if(gradB) {
|
||||
if(gradB->rankOf() == 2)
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
|
||||
if(gradB != OUTPUT_VARIABLE(2))
|
||||
delete gradB;
|
||||
|
|
|
@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
|
|||
// ----- calculation of gradB ----- //
|
||||
if(gradB) {
|
||||
if(gradB->rankOf() == 2)
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
|
||||
if(gradB != OUTPUT_VARIABLE(2))
|
||||
delete gradB;
|
||||
|
|
|
@ -61,13 +61,13 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||
|
||||
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(resize_area) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto shapeList = SHAPELIST();
|
||||
auto in = inputShape->at(0);
|
||||
|
||||
Nd4jLong* outputShape;
|
||||
|
@ -90,7 +90,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank);
|
||||
|
||||
|
||||
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
|
||||
outputShape[0] = inRank;
|
||||
if (inRank == 4) {
|
||||
|
|
|
@ -62,13 +62,13 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
|
||||
|
||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||
|
||||
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(resize_bicubic) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto shapeList = SHAPELIST();
|
||||
auto in = inputShape->at(0);
|
||||
|
||||
Nd4jLong* outputShape;
|
||||
|
@ -82,7 +82,7 @@ namespace nd4j {
|
|||
height = newImageSize->e<int>(1);
|
||||
|
||||
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
|
||||
|
||||
|
||||
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
|
||||
outputShape[0] = inRank;
|
||||
if (inRank == 4) {
|
||||
|
|
|
@ -43,7 +43,7 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
|
||||
|
||||
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||
|
||||
if (block.width() > 1) {
|
||||
auto newImageSize = INPUT_VARIABLE(1);
|
||||
|
@ -71,7 +71,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
DECLARE_SHAPE_FN(resize_bilinear) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto shapeList = SHAPELIST();
|
||||
auto in = inputShape->at(0);
|
||||
|
||||
Nd4jLong* outputShape;
|
||||
|
@ -94,7 +94,7 @@ namespace nd4j {
|
|||
width = INT_ARG(0);
|
||||
height = INT_ARG(1);
|
||||
}
|
||||
|
||||
|
||||
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
|
||||
outputShape[0] = inRank;
|
||||
if (inRank == 4) {
|
||||
|
|
|
@ -63,13 +63,13 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
|
||||
|
||||
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
|
||||
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
|
||||
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
|
||||
|
||||
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
|
||||
auto shapeList = SHAPELIST();
|
||||
auto shapeList = SHAPELIST();
|
||||
auto in = inputShape->at(0);
|
||||
auto inRank = shape::rank(in);
|
||||
Nd4jLong* outputShape;
|
||||
|
|
|
@ -36,14 +36,14 @@ namespace nd4j {
|
|||
int _a = INT_ARG(e);
|
||||
if (_a < 0)
|
||||
_a += input->rankOf();
|
||||
|
||||
|
||||
axis.emplace_back(_a);
|
||||
}
|
||||
else if (block.width() > 1) {
|
||||
auto a = INPUT_VARIABLE(1);
|
||||
for (Nd4jLong e = 0; e < a->lengthOf(); e++) {
|
||||
int _a = a->e<int>(e);
|
||||
|
||||
|
||||
if (_a < 0)
|
||||
_a += input->rankOf();
|
||||
|
||||
|
@ -71,7 +71,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
if (block.isInplace()) {
|
||||
output->reshapei(input->ordering(), shape);
|
||||
output->reshapei(input->ordering(), shape, false);
|
||||
} else {
|
||||
auto tmp = input->reshape(input->ordering(), shape);
|
||||
output->assign(tmp);
|
||||
|
@ -106,20 +106,20 @@ namespace nd4j {
|
|||
int _a = INT_ARG(e);
|
||||
if (_a < 0)
|
||||
_a += rank;
|
||||
|
||||
|
||||
axis.emplace_back(_a);
|
||||
}
|
||||
else if (block.width() > 1) {
|
||||
auto a = INPUT_VARIABLE(1);
|
||||
for (int e = 0; e < a->lengthOf(); e++) {
|
||||
int _a = a->e<int>(e);
|
||||
|
||||
|
||||
if (_a < 0)
|
||||
_a += rank;
|
||||
|
||||
axis.emplace_back(_a);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
auto order = shape::order(in);
|
||||
|
|
|
@ -57,7 +57,8 @@ namespace nd4j {
|
|||
* IArgs[1]... axes values for second array
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_tensormmul)
|
||||
DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
|
||||
DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
|
||||
DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1);
|
||||
#endif
|
||||
|
||||
/**
|
||||
|
|
|
@ -432,7 +432,7 @@ namespace nd4j {
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
|
||||
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
|
||||
|
||||
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
|
||||
|
@ -505,7 +505,7 @@ namespace nd4j {
|
|||
if(gradB) {
|
||||
NDArray* gradBR = gradB;
|
||||
if(gradB->rankOf() == 2)
|
||||
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
|
||||
|
||||
if(gradBR != gradB)
|
||||
|
|
|
@ -30,7 +30,7 @@ namespace helpers {
|
|||
void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
|
||||
auto _a = a->reshape(a->ordering(), {-1, 3});
|
||||
auto _b = b->reshape(b->ordering(), {-1, 3});
|
||||
auto _o = o->reshape(o->ordering(), {-1, 3});
|
||||
auto _o = o->reshape(o->ordering(), {-1, 3}, false);
|
||||
|
||||
auto tadsA = _a.allTensorsAlongDimension({1});
|
||||
auto tadsB = _b.allTensorsAlongDimension({1});
|
||||
|
|
|
@ -244,14 +244,14 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
|
|||
|
||||
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
|
||||
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)});
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false);
|
||||
outputRearranged0.permutei({2, 3,0, 4,1, 5});
|
||||
|
||||
if(input.lengthOf() == output.lengthOf()) {
|
||||
outputRearranged0.assign(input);
|
||||
}
|
||||
else {
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)});
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false);
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);
|
||||
|
||||
if(output.getBuffer() != outputRearranged1.getBuffer())
|
||||
|
@ -352,7 +352,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
|
|||
for(int j = 1; j < rank; ++i, ++j)
|
||||
temp[i] = output.sizeAt(j);
|
||||
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
|
||||
|
||||
//*** construct permuting std::vector for permutation of output array ***//
|
||||
|
||||
|
@ -382,7 +382,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
|
|||
for(i = 1; i < rank; ++i)
|
||||
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
|
||||
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
|
||||
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES);
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND
|
|||
void FORCEINLINE _crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
|
||||
auto a_ = a->reshape(a->ordering(), {-1, 3});
|
||||
auto b_ = b->reshape(b->ordering(), {-1, 3});
|
||||
auto o_ = o->reshape(o->ordering(), {-1, 3});
|
||||
auto o_ = o->reshape(o->ordering(), {-1, 3}, false);
|
||||
|
||||
auto tadsA = a_.allTensorsAlongDimension({1});
|
||||
auto tadsB = b_.allTensorsAlongDimension({1});
|
||||
|
|
|
@ -322,7 +322,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
|
|||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
|
||||
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
|
||||
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
|
||||
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
|
||||
|
||||
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
|
||||
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
|
||||
|
@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
|
|||
NDArray* gradBR = gradB;
|
||||
if(gradB->rankOf() == 2)
|
||||
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW
|
||||
if(gradBR != gradB)
|
||||
delete gradBR;
|
||||
}
|
||||
|
@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
|
|||
NDArray* gradBR = gradB;
|
||||
if(gradB->rankOf() == 2)
|
||||
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
|
||||
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW
|
||||
if(gradBR != gradB)
|
||||
delete gradBR;
|
||||
}
|
||||
|
|
|
@ -313,7 +313,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
|
|||
|
||||
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
|
||||
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false);
|
||||
outputRearranged0.permutei({2, 3,0, 4,1, 5});
|
||||
|
||||
if(input.lengthOf() == output.lengthOf()) {
|
||||
|
@ -322,7 +322,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
|
|||
}
|
||||
else {
|
||||
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false);
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
@ -439,7 +439,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
|
|||
for(int j = 1; j < rank; ++i, ++j)
|
||||
temp[i] = output.sizeAt(j);
|
||||
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
|
||||
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
|
||||
|
||||
//*** construct permuting std::vector for permutation of output array ***//
|
||||
|
||||
|
@ -469,7 +469,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
|
|||
for(i = 1; i < rank; ++i)
|
||||
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
|
||||
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
|
||||
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
|
|
@ -471,9 +471,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
|
|||
if(cI)
|
||||
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
|
||||
if(hL)
|
||||
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
|
||||
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
|
||||
if(cL)
|
||||
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));
|
||||
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));
|
||||
|
||||
lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);
|
||||
|
||||
|
|
|
@ -321,6 +321,280 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
|
|||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot5) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot6) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot7) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot8) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot9) {
|
||||
|
||||
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
|
||||
// z.linspace(1);
|
||||
// z.printShapeInfo();
|
||||
// z.printIndexedBuffer();
|
||||
// z.reshapei('c', {4,3});
|
||||
// z.printShapeInfo();
|
||||
// z.printIndexedBuffer();
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot10) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot11) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot12) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot13) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot14) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot15) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot16) {
|
||||
|
||||
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, TestTensorDot17) {
|
||||
|
||||
NDArray x('f', {16,16}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('f', {1000,16}, nd4j::DataType::FLOAT32);
|
||||
NDArray z('c', {16,1000}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto status = op.execute({&x, &y}, {&z}, {}, {1,1, 1,1}, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests1, DivergentCheck1) {
|
||||
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation("switch");
|
||||
|
|
|
@ -708,30 +708,6 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
|
|||
ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, tensormmul_6) {
|
||||
|
||||
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
|
||||
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
// exp.printShapeInfo();
|
||||
// result->printShapeInfo();
|
||||
// result->printIndexedBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(result));
|
||||
ASSERT_TRUE(exp.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {
|
||||
|
||||
|
|
|
@ -1560,3 +1560,447 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
|
|||
|
||||
delete resultsB;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
|
||||
|
||||
NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
|
||||
|
||||
NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 1 }, { 1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(B.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(B.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(A.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(A.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) {
|
||||
|
||||
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
|
||||
|
||||
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
|
||||
|
||||
NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
|
||||
|
||||
NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto dLdC = NDArrayFactory::create<float>(1.f);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(B.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(B.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(A.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(A.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) {
|
||||
|
||||
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
|
||||
|
||||
NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) {
|
||||
|
||||
NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) {
|
||||
|
||||
NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) {
|
||||
|
||||
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) {
|
||||
|
||||
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
|
||||
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) {
|
||||
|
||||
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::DOUBLE);
|
||||
NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::DOUBLE);
|
||||
NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
|
||||
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, nd4j::DataType::DOUBLE);
|
||||
NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) {
|
||||
|
||||
NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
|
||||
1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, nd4j::DataType::DOUBLE);
|
||||
NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
|
||||
|
||||
auto* dLdAbp = resultsBP->at(0);
|
||||
auto* dLdBbp = resultsBP->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
|
||||
|
||||
delete resultsBP;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) {
|
||||
|
||||
NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
|
||||
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, nd4j::DataType::FLOAT32);
|
||||
NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
nd4j::ops::tensormmul_bp op;
|
||||
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 });
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto* dLdA = results->at(0);
|
||||
auto* dLdB = results->at(1);
|
||||
|
||||
ASSERT_TRUE(dA.isSameShape(*dLdA));
|
||||
ASSERT_TRUE(dA.equalsTo(*dLdA));
|
||||
|
||||
ASSERT_TRUE(dB.isSameShape(*dLdB));
|
||||
ASSERT_TRUE(dB.equalsTo(*dLdB));
|
||||
|
||||
delete results;
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
|
||||
|
||||
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
|
||||
NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
|
||||
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
|
||||
|
||||
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
|
||||
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
|
||||
|
||||
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
|
||||
|
||||
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
|
||||
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
nd4j::ops::tensormmul_bp op_bp;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
|
|
|
@ -578,246 +578,6 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
|
|||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot5) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot6) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot7) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot8) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot9) {
|
||||
|
||||
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
|
||||
// z.linspace(1);
|
||||
// z.printShapeInfo();
|
||||
// z.printIndexedBuffer();
|
||||
// z.reshapei('c', {4,3});
|
||||
// z.printShapeInfo();
|
||||
// z.printIndexedBuffer();
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot10) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot11) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot12) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot13) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot14) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, TestTensorDot15) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
|
||||
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
|
||||
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
|
||||
|
||||
nd4j::ops::tensormmul op;
|
||||
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||
|
||||
auto *result = results->at(0);
|
||||
|
||||
ASSERT_TRUE(expected.isSameShape(result));
|
||||
ASSERT_TRUE(expected.equalsTo(result));
|
||||
|
||||
delete results;
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {
|
||||
|
|
|
@ -2043,34 +2043,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
|
|||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
|
||||
//************************************//
|
||||
/* exclusive = 1; reverse = 0;
|
||||
|
||||
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
z = result->at(0);
|
||||
ASSERT_TRUE(expTF.equalsTo(z));
|
||||
delete result;
|
||||
*/
|
||||
//************************************//
|
||||
/* exclusive = 0; reverse = 1;
|
||||
|
||||
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
z = result->at(0);
|
||||
ASSERT_TRUE(expFT.equalsTo(z));
|
||||
delete result;
|
||||
*/
|
||||
//************************************//
|
||||
/* exclusive = 1; reverse = 1;
|
||||
|
||||
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
z = result->at(0);
|
||||
ASSERT_TRUE(expTT.equalsTo(z));
|
||||
delete result;
|
||||
*/
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2079,11 +2051,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
|
|||
auto inputC = NDArrayFactory::create<double>('c', {2, 2});
|
||||
auto axis = NDArrayFactory::create<double>(1.);
|
||||
|
||||
// auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.});
|
||||
// auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024});
|
||||
|
||||
// auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++
|
||||
// auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
|
||||
auto gradO = NDArrayFactory::create<double>('c', {2, 2});
|
||||
|
||||
int exclusive, reverse;
|
||||
|
|
|
@ -61,7 +61,7 @@ public class TestPCA extends BaseNd4jTest {
|
|||
assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff.getDouble(i), 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFactorSVDTransposed() {
|
||||
int m = 4;
|
||||
|
|
Loading…
Reference in New Issue