Oleh tenzor mmul (#231)

* Libnd4j: TensorMMul backprop op #8174, raw implementation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 merge master and some corrections

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 algorithm update, need testing, sync with  master

* Libnd4j: TensorMMul backprop op #8174 fixed incorrect B axes calculation

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 optimize axes identification and fix bug of indeces overlapping, added first test. need testing with different shapes

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 some fixes and improvements need more testing

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed order of matrix multiply

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed issue of incorrect axes definition, add tests based on TF, need additional testing for case dLdC not equal 1

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed scalar case add test

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 fixed bp algorithm, axes definition, need some mode testing with different orders combination f,c; c,f f,f and add some checks for inputs

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 some checks and corrections added tests, exists the problem with different input orders support A-f B-c and A-f B-f

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* Libnd4j: TensorMMul backprop op #8174 sync master

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - correct bug in MmulHelper::tensorDot(a, b, c, axes_a, axes_b,permutForC)

Signed-off-by: Yurii <iuriish@yahoo.com>

* Libnd4j: TensorMMul backprop op #8174 code clean up and refactoring

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* - add check for linspase ordered permutations in ShapeUtils::evalShapeForTensorDot

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide additional code in shape::reshape stuff in order to reduce amount of allocation/copy operations during reshaping procedure

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on problem of wrong shape evaluation during permute/reshape procedures

Signed-off-by: Yurii <iuriish@yahoo.com>

* - still looking for bug reason in reshape/permute stuff

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in transform cuda native ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in NDArray::assign

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove old shape::reshape stuff

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add possibility to disable copy of old buffer to new buffer during reshape operation in NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct bug in tensorDot which had to do with wrong pointers assigments

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: Oleh <oleg.semeniv@gmail.com>
master
Yurii Shyrma 2020-02-13 19:33:54 +02:00 committed by GitHub
parent 8c0e378ec3
commit fe47f52896
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 1524 additions and 901 deletions

View File

@ -999,14 +999,14 @@ namespace nd4j {
* set new order and shape in case of suitable array length (in-place operation)
* order - order to set
* shape - shape to set
*
* copyToNewBuff - if true then old buffer will be copied to new buffer if last one will be allocated after reshaping
* if there was permute applied before or there are weird strides, then new buffer is allocated for array
*/
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape);
bool reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::initializer_list<Nd4jLong>& shape);
bool reshapei(const std::vector<Nd4jLong>& shape);
bool reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff = true);
bool reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true);
/**
* creates new array with corresponding order and shape, new array will point on _buffer of this array
@ -1015,8 +1015,8 @@ namespace nd4j {
*
* if permute have been applied before or there are weird strides, then new buffer is allocated for new array
*/
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape) &&;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) const &;
NDArray reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff = true) &&;
/**
* calculate strides and set given order
@ -1493,7 +1493,7 @@ namespace nd4j {
* @return
*/
bool isS() const;
template <typename T>
std::vector<T> asVectorT();

View File

@ -42,7 +42,7 @@ ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const;
////////////////////////////////////////////////////////////////////////
// copy constructor
NDArray::NDArray(const NDArray& other) {
_context = other._context;
_offset = 0;
@ -308,7 +308,7 @@ NDArray::NDArray(const std::u16string& u16string, nd4j::DataType dtype, nd4j::La
if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) {
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
}
// one word that is why used 1
Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1);
@ -435,11 +435,11 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
_offset = 0;
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype));
memcpy(bufferAsT<int8_t>(), &offsets[0], 2 * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
if (dtype == DataType::UTF8) {
memcpy(data, str.data(), str.size());
}
@ -456,13 +456,13 @@ NDArray::NDArray(const std::string& str, nd4j::DataType dtype, nd4j::LaunchConte
/////////////////////////////////////////////////////////////////////////
// constructors for vector of strings
NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const char*>& string, const nd4j::DataType dataType, nd4j::LaunchContext* context) {
if (!DataTypeUtils::isS(dataType))
throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used");
if (shape::prodLong(shape.data(), shape.size()) != string.size())
throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array");
for (const auto& str : string) {
if (!unicode::isStringValidU8(str, str + std::char_traits<char>::length(str)) ) {
throw std::invalid_argument("NDArray::NDArray: invalid character in input string");
@ -497,7 +497,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
setAttached(context->getWorkspace() != nullptr);
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
@ -631,9 +631,9 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u16s
setAttached(context->getWorkspace() != nullptr);
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
auto cdata = data + offsets[e];
@ -699,7 +699,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
auto cdata = data + offsets[e];
@ -715,7 +715,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
@ -764,8 +764,8 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
auto cdata = data + offsets[e];
@ -781,7 +781,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<std::u32s
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
@ -831,7 +831,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
memcpy(bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
auto data = reinterpret_cast<int8_t*>(bufferAsT<int8_t>() + headerLength);
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e += increment) {
auto cdata = data + offsets[e];
@ -847,7 +847,7 @@ NDArray::NDArray(const std::vector<Nd4jLong>& shape, const std::vector<const cha
}
};
samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
tickWriteHost();
syncToDevice();
}
@ -887,8 +887,8 @@ bool NDArray::isC() const {
//////////////////////////////////////////////////////////////////////////
bool NDArray::isS() const {
return (dataType() == DataType::UTF8 ||
dataType() == DataType::UTF16 ||
return (dataType() == DataType::UTF8 ||
dataType() == DataType::UTF16 ||
dataType() == DataType::UTF32);
}
@ -1197,8 +1197,8 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched");
}
// memcpy is allowed only for same order && same ews (being equal to 1)
if (ordering() == other.ordering() && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
// memcpy is allowed only for same order c && same ews (being equal to 1)
if (ordering() == other.ordering() && ordering() == 'c' && dataType() == other.dataType() && ews() == 1 && other.ews() == 1)
copyBuffersContinuouslyFrom(other, other.lengthOf() * other.sizeOfT());
else {
NDArray::prepareSpecialUse({this}, {&other});
@ -1569,20 +1569,25 @@ Nd4jLong NDArray::tensorsAlongDimension(const std::vector<int>& dimensions) cons
//////////////////////////////////////////////////////////////////////////
void NDArray::printShapeInfo(const char * msg) const {
//shape::printShapeInfo(_shapeInfo);
if (msg == nullptr)
shape::printShapeInfoLinear(_shapeInfo);
else {
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
printf("%s: [", msg);
for (int i = 0; i < shape::shapeInfoLength(rank); i++) {
printf("%lld", (long long) _shapeInfo[i]);
if (i < lim - 1)
printf(", ");
}
printf("]\n");
int rank = shape::rank(_shapeInfo);
int lim = shape::shapeInfoLength(rank);
if(msg != nullptr)
printf("shapeInfo %s: [", msg);
else
printf("shapeInfo: [");
printf("%i, ", rank);
for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){
if(i == rank + 1)
printf(" ");
printf("%lld,", _shapeInfo[i]);
}
printf(" %lld,", shape::type(_shapeInfo));
printf("%lld,", shape::elementWiseStride(_shapeInfo));
printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo));
fflush(stdout);
}
@ -1624,7 +1629,7 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons
if (e < limit - 1)
printf(", ");
}
}
}
else if (this->isS()) {
// todo do we need this print offsets
/*
@ -1773,7 +1778,7 @@ void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const {
printf("%s\n", this->e<bool>(0)?"true":"false");
}
else if (this->isS()) {
// todo do we need this
// todo do we need this
// printf("\"%lld\"\n", this->getOffset(e));
printf("\"%s\"\n", this->e<std::string>(0).c_str());
}
@ -1855,19 +1860,19 @@ void NDArray::updateStrides(const char order) {
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape) {
bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
std::vector<Nd4jLong> vShape(shape);
return reshapei(order, vShape);
return reshapei(order, vShape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape) {
return reshapei('c', shape);
bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
return reshapei(ordering(), shape, copyToNewBuff);
}
//////////////////////////////////////////////////////////////////////////
@ -1918,18 +1923,18 @@ Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
//////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) const & {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
NDArray newArr(getDataBuffer(), ShapeDescriptor(getShapeInfo()), getContext(), getBufferOffset());
newArr.reshapei(order, shape);
newArr.reshapei(order, shape, copyToNewBuff);
return newArr;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape) && {
NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
this->reshapei(order, shape);
this->reshapei(order, shape, copyToNewBuff);
return std::move(*this);
}
@ -2280,7 +2285,7 @@ template <typename T>
NDArray NDArray::asT() const{
auto result = isScalar() ? NDArray('c', {}, std::vector<double>{0.}, DataTypeUtils::fromT<T>(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT<T>(), this->getContext());
NDArray::prepareSpecialUse({&result}, {this});
NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), result.getBuffer(), result.getShapeInfo(), result.getSpecialBuffer(), result.getSpecialShapeInfo(), nullptr, nullptr, nullptr);
NDArray::registerSpecialUse({&result}, {this});
@ -2298,15 +2303,15 @@ NDArray NDArray::asS() const {
auto dtype = DataTypeUtils::fromT<T>();
if (!(DataTypeUtils::isS(dtype)))
if (!(DataTypeUtils::isS(dtype)))
throw std::invalid_argument("NDArray::asS: invalid DataType used");
if (dtype == dataType()) {
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
const auto nInputoffsets = bufferAsT<Nd4jLong>();
std::shared_ptr<DataBuffer> pBuffer = std::make_shared<DataBuffer>(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true);
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
res.setAttached(getContext()->getWorkspace() != nullptr);
@ -2319,7 +2324,7 @@ NDArray NDArray::asS() const {
registerPrimaryUse({ &res }, { this });
return res;
}
Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf());
std::vector<Nd4jLong> offsets(lengthOf() + 1);
@ -2353,7 +2358,7 @@ NDArray NDArray::asS() const {
NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext());
res.setAttached(getContext()->getWorkspace() != nullptr);
preparePrimaryUse({ &res }, { this });
memcpy(res.bufferAsT<int8_t>(), offsets.data(), offsets.size() * sizeof(Nd4jLong));
@ -2403,7 +2408,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND
////////////////////////////////////////////////////////////////////////
NDArray NDArray::asT(DataType dtype) const {
if (isS() && !DataTypeUtils::isS(dtype))
throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!");
@ -3221,7 +3226,7 @@ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LI
//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
// check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
@ -3293,19 +3298,15 @@ bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape) {
Nd4jLong *shapeInfoNew;
ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
bool canReshape = shape::reshapeC(rankOf(), shapeInfo(), shape.size(), shape.data(), shapeInfoNew);
bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
// we can do this only if there was no permute applied, or there are no weird strides
if (canReshape) {
if(ordering() == 'c' && order == 'f')
throw std::invalid_argument("NDArray::reshapei(order, shape): in case of reshapeC it doesn't make sense to reshape from c order to f order !");
shape::setEws(shapeInfoNew, arrLength);
setShapeInfo(shapeInfoNew);
}
else {
NDArray temp(order, shape, dataType(), getContext());
this->applyTransform(transform::Assign, temp, nullptr);
if(copyToNewBuff)
this->applyTransform(transform::Assign, temp, nullptr);
*this = std::move(temp);
}
@ -3463,7 +3464,7 @@ NDArray NDArray::dup(const char newOrder) const {
if (isS()) {
if (dataType() == DataType::UTF8) {
std::vector<std::string> strings(lengthOf());
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i += increment) {
strings[i] = std::move(this->e<std::string>(i));
@ -3521,7 +3522,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
if (isS()) {
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
if (dataType() == DataType::UTF8) {
for (int e = 0; e < this->lengthOf(); e++) {
auto s1 = this->e<std::string>(e);
@ -3585,7 +3586,7 @@ std::string NDArray::e(const Nd4jLong i) const {
if (i == lengthOf())
throw std::runtime_error("Can't get std::string for index out of range");
if (this->dataType() == DataType::UTF16) {
auto u16 = this->e<std::u16string>(i);
std::string s;
@ -4846,7 +4847,7 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
auto shapeOf = shape::shapeOf(newShapeInfo);
auto stridesOf = shape::stride(newShapeInfo);
Nd4jLong offset(0), subArrLen(1);
Nd4jLong offset = 0;
int n(isStrided ? 3 : 2), first, last, stride;
for (int d = rank - 1; d >= 0; --d) {
@ -4863,29 +4864,31 @@ NDArray NDArray::operator()(const std::vector<Nd4jLong>& idx, const bool keepUni
if(shapeOf[d] != 1)
stridesOf[d] *= stride;
}
}
subArrLen *= shapeOf[d];
Nd4jLong *shapeInfoNoUnities = newShapeInfo;
if(!keepUnitiesInShape) {
std::vector<int> dimsWithUnities;
for (uint d = 0; d < rank; ++d)
if(idx[n*d] != idx[n*d+1] && shapeOf[d] == 1)
dimsWithUnities.push_back(d);
if(!dimsWithUnities.empty())
shapeInfoNoUnities = ShapeBuilders::copyShapeInfoWithoutUnites(newShapeInfo, dimsWithUnities.size(), dimsWithUnities.data(), getContext()->getWorkspace());
}
// check if there is possibility to set ews = 1
shape::setEws(newShapeInfo, subArrLen);
shape::checkStridesSetEwsAndOrder(shapeInfoNoUnities);
NDArray result(_buffer, ShapeDescriptor(newShapeInfo), getContext(), offset + getBufferOffset());
NDArray result(_buffer, ShapeDescriptor(shapeInfoNoUnities), getContext(), offset + getBufferOffset());
result._isView = true;
if(!keepUnitiesInShape) {
const int coeff = isStrided ? 3 : 2;
std::vector<Nd4jLong> nonUnitDims;
for (int d = 0; d < rank; ++d)
if(!(idx[coeff*d] != idx[coeff*d+1] && newShapeInfo[d+1] == 1))
nonUnitDims.push_back(newShapeInfo[d+1]);
if(nonUnitDims.size() != rank)
result.reshapei(nonUnitDims);
}
RELEASE(newShapeInfo, getContext()->getWorkspace());
if(newShapeInfo != shapeInfoNoUnities)
RELEASE(shapeInfoNoUnities, getContext()->getWorkspace());
return result;
}

View File

@ -30,15 +30,15 @@
namespace nd4j {
class ND4J_EXPORT ShapeBuilders {
public:
public:
static Nd4jLong* createScalarShapeInfo(nd4j::DataType dataType, nd4j::memory::Workspace* workspace = nullptr);
static Nd4jLong* createVectorShapeInfo(const nd4j::DataType dataType, const Nd4jLong length, nd4j::memory::Workspace* workspace = nullptr);
/**
* create shapeInfo for given order basing on shape stored in shapeOnly vector
* memory allocation for shapeInfo is on given workspace
*/
*/
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, int rank, const Nd4jLong* shapeOnly, memory::Workspace* workspace = nullptr);
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
static Nd4jLong* createShapeInfo(const nd4j::DataType dataType, const char order, const std::initializer_list<Nd4jLong>& shapeOnly, memory::Workspace* workspace = nullptr);
@ -51,6 +51,13 @@ namespace nd4j {
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const DataType dtype, const bool copyStrides, memory::Workspace* workspace = nullptr);
static Nd4jLong* copyShapeInfoAndType(const Nd4jLong* inShapeInfo, const Nd4jLong* shapeInfoToGetTypeFrom, const bool copyStrides, memory::Workspace* workspace = nullptr);
/**
* allocates memory for new shapeInfo and copy all information from inShapeInfo to new shapeInfo except dimensions in dimsToExclude (unit dimensions) and corresponding strides
* for example inShapeInfo is {3, 2,1,3,1,4, 12,12,4,4,1, 16384,1,99}, dimsToExclude = {2,3}, dimsSize = 2
* then outShapeInfo will contain {3, 2,3,4, 12,4,1, 16384,1,99}
*/
static Nd4jLong* copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, memory::Workspace* workspace = nullptr);
static Nd4jLong* emptyShapeInfo(const nd4j::DataType dataType, const char order, const std::vector<Nd4jLong> &shape, memory::Workspace* workspace = nullptr);

View File

@ -68,7 +68,7 @@ namespace nd4j {
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(shapeInfo, dimsToExclude);
const int subArrRank = (rank == dimsToExclude.size() || descriptor.areUnitiesinShape()) ? rank : rank - dimsToExclude.size();
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)];
auto sPtr = new Nd4jLong[shape::shapeInfoLength(subArrRank)]; // shape of sub-arrays (same for all for them)
auto oPtr = new Nd4jLong[numOfSubArrs];
if (numOfSubArrs > 0)

View File

@ -43,23 +43,30 @@ nd4j::NDArray* nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::N
auto outShape = ShapeUtils::evalShapeForTensorDot(a, b, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
NDArray aPR = a->permute(permutAt);
NDArray bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei( shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei( shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
NDArray* c = mmul(&aPR, &bPR, nullptr, 1.0, 0.0);
NDArray* c = mmul(aPR, bPR, nullptr, 1.0, 0.0);
c->reshapei(outShape);
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
return c;
}
//////////////////////////////////////////////////////////////////////////
void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b, nd4j::NDArray* c, const std::vector<int>& axes_a, const std::vector<int>& axes_b, const std::vector<int>& permutForC) {
@ -67,32 +74,38 @@ void nd4j::MmulHelper::tensorDot(const nd4j::NDArray* a, const nd4j::NDArray* b,
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(a, b, axes_a, axes_b, permutAt, permutBt, shapeAt, shapeBt);
NDArray *cP(c), *cPR(c);
// check whether permutation is required
if(!permutForC.empty())
cP = new NDArray(c->permute(permutForC));
NDArray* cP = permutForC.empty() ? c : new NDArray(c->permute(permutForC));
auto aPR = a->permute(permutAt);
auto bPR = b->permute(permutBt);
// check whether permutation is necessary
const NDArray* aP = permutAt.empty() ? a : new NDArray(a->permute(permutAt));
const NDArray* bP = permutBt.empty() ? b : new NDArray(b->permute(permutBt));
// check whether reshape is necessary
if(!aPR.isSameShape(shapeAt))
aPR.reshapei(shapeAt);
if(!bPR.isSameShape(shapeBt))
bPR.reshapei(shapeBt);
const NDArray* aPR = aP->isSameShape(shapeAt) ? aP : new NDArray(aP->reshape(aP->ordering(), shapeAt));
const NDArray* bPR = bP->isSameShape(shapeAt) ? bP : new NDArray(bP->reshape(bP->ordering(), shapeBt));
if(!cP->isSameShape({aPR.sizeAt(0), bPR.sizeAt(1)}))
cPR = new NDArray(cP->reshape(cP->ordering(), {aPR.sizeAt(0), bPR.sizeAt(1)}));
std::vector<Nd4jLong> requiredCshape = {aPR->sizeAt(0), bPR->sizeAt(1)};
mmul(&aPR, &bPR, cPR, 1.0, 0.0);
NDArray* cPR = cP->isSameShape(requiredCshape) ? cP : new NDArray(cP->reshape(cP->ordering(), requiredCshape, false));
mmul(aPR, bPR, cPR, 1.0, 0.0);
if(cPR->getBuffer() != cP->getBuffer() || cPR->getSpecialBuffer() != cP->getSpecialBuffer() ) // this means both permute and reshape have been performed on c, cP always points on c->getBuffer()
cP->assign(cPR);
if(cPR != c)
if(aP != aPR)
delete aPR;
if(bP != bPR)
delete bPR;
if(a != aP)
delete aP;
if(b != bP)
delete bP;
if(cP != cPR)
delete cPR;
if(cP != c)
if(c != cP)
delete cP;
}
@ -129,7 +142,7 @@ void nd4j::MmulHelper::tensorDot(const NDArray* a, const NDArray* b, NDArray* c,
if(!whatToDoWithC.empty()) {
cArrs = std::vector<NDArray*>(whatToDoWithC.size()+1, c);
for(int i = 0; i < cArrs.size()-1; ++i)
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i])); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
cArrs[i+1] = (whatToDoWithC[i] == 'p') ? new NDArray(cArrs[i]->permute(modifC[i])) : new NDArray(cArrs[i]->reshape(c->ordering(), modifC[i], false)); // since we ignore first element in cArrs (that is cArrs[0]) then it is always equal to c
}
mmul(aPR, bPR, cArrs[cArrs.size()-1], 1.0, 0.0);
@ -208,7 +221,7 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
if(isAVector && bRank == 2) {
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()}, false)) : nullptr; // C{N} -> C2{1,N}
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
delete A2;
delete C2;

View File

@ -139,5 +139,15 @@ namespace nd4j {
return ShapeBuilders::copyShapeInfoAndType(inShapeInfo, ArrayOptions::dataType(shapeInfoToGetTypeFrom), copyStrides, workspace);
}
////////////////////////////////////////////////////////////////////////////////
Nd4jLong* ShapeBuilders::copyShapeInfoWithoutUnites(const Nd4jLong* inShapeInfo, const int dimsSize, const int* dimsToExclude, memory::Workspace* workspace) {
Nd4jLong *outShapeInfo = nullptr;
ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo[0] - dimsSize), Nd4jLong);
shape::excludeUnitiesFromShapeInfo(inShapeInfo, dimsSize, dimsToExclude, outShapeInfo);
return outShapeInfo;
}
}

View File

@ -75,10 +75,23 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
permutBt = axesB;
permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());
// if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
uint i1, i2;
for(i1 = 0; i1 < aRank; ++i1)
if(permutAt[i1] != i1)
break;
if(i1 == aRank)
permutAt = {};
for(i2 = 0; i2 < bRank; ++i2)
if(permutBt[i2] != i2)
break;
if(i2 == bRank)
permutBt = {};
Nd4jLong n2 = 1;
for (int i = 0; i < axeAsize; i++)
n2 *= aShapeInfo[axesA[i] + 1];
shapeAt = {-1, n2};
shapeAt = {shape::length(aShapeInfo) / n2, n2};
std::vector<Nd4jLong> oldShapeA;
oldShapeA.resize(list_A.size());
@ -89,7 +102,7 @@ std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeIn
Nd4jLong n3 = 1;
for (int i = 0; i < axeBsize; i++)
n3 *= bShapeInfo[axesB[i] + 1];
shapeBt = {n3, -1};
shapeBt = {n3, shape::length(bShapeInfo) / n3};
std::vector<Nd4jLong> oldShapeB;
oldShapeB.resize(list_B.size());
@ -306,10 +319,10 @@ std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<in
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
if (!arr.nonNull())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");
if (rank != arr.rankOf())
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");
throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");
auto shapeInfoLength = shape::shapeInfoLength(rank);
// allocate memory for new array - shapeInfo

File diff suppressed because it is too large Load Diff

View File

@ -84,7 +84,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -97,7 +97,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (Nd4jLong i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -87,7 +87,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -89,7 +89,7 @@ namespace functions {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int totalThreads = gridDim.x * blockDim.x;
if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
if(xEws > 0 && zEws > 0 && xOrder == zOrder && xOrder == 'c') {
for (int i = tid; i < length; i += totalThreads)
z[i * zEws] = OpType::op(x[i * xEws], params);

View File

@ -21,70 +21,174 @@
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_tensormmul)
#include <numeric>
#include <helpers/ShapeUtils.h>
#include <ops/declarable/CustomOperations.h>
#include <MmulHelper.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
namespace ops {
auto c = OUTPUT_VARIABLE(0); //
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul, 2, 1, false, 0, -1) {
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
auto a = INPUT_VARIABLE(0);
auto b = INPUT_VARIABLE(1);
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
auto c = OUTPUT_VARIABLE(0);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
REQUIRE_TRUE(a->dataType() == b->dataType(), 0, "tensormmul: A, B and C data types must be the same");
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int)INT_ARG(e + 1);
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int)INT_ARG(e + axe0_size + 2);
nd4j_verbose("axe0: %i; axe1: %i;\n", axes_0.size(), axes_1.size());
DECLARE_SHAPE_FN(tensormmul) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
MmulHelper::tensorDot(a, b, c, axes_0, axes_1);
return Status::OK();
}
DECLARE_SYN(tensordot, tensormmul);
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul) {
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
REQUIRE_TRUE(ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo), 0, "tensormmul: A and B data types must be the same");
// evaluate shapes
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
// building axes
int axe0_size = INT_ARG(0);
int axe1_size = INT_ARG(axe0_size+1);
std::vector<int> axes_0(axe0_size), axes_1(axe1_size);
for (int e = 0; e < axe0_size; e++)
axes_0[e] = (int) INT_ARG(e+1);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
for (int e = 0; e < axe1_size; e++)
axes_1[e] = (int) INT_ARG(e + axe0_size + 2);
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
// evaluate shapes
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
auto outShape = nd4j::ShapeUtils::evalShapeForTensorDot(aShapeInfo, bShapeInfo, axes_0, axes_1, permutAt, permutBt, shapeAt, shapeBt);
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(ArrayOptions::dataType(aShapeInfo), 'c', outShape)));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul) {
getOpDescriptor()
->setAllowedInputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedInputTypes(1, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF})
->setAllowedOutputTypes(0, {DataType::FLOAT32, DataType ::DOUBLE, DataType::HALF});
}
////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(tensormmul_bp, 3, 2, false, 0, -1) {
auto A = INPUT_VARIABLE(0);
auto B = INPUT_VARIABLE(1);
auto dLdC = INPUT_VARIABLE(2);
auto dLdA = OUTPUT_VARIABLE(0);
auto dLdB = OUTPUT_VARIABLE(1);
REQUIRE_TRUE( (A->dataType() == B->dataType() && (dLdC->dataType() == A->dataType())), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
int axe0Size = INT_ARG(0);
int axe1Size = INT_ARG(axe0Size + 1);
auto Arank = A->rankOf();
auto Brank = B->rankOf();
auto dLdCrank = dLdC->rankOf();
REQUIRE_TRUE((Arank >= axe0Size), 0, "tensormmul_bp: A rank must be the higher or same as input axes 0");
REQUIRE_TRUE((Brank >= axe1Size), 0, "tensormmul_bp: B rank must be the higher or same as input axes 1");
// building axes
std::vector<int> axes0(axe0Size), axes1(axe1Size);
for (uint e = 0; e < axe0Size; e++)
axes0[e] = (int)INT_ARG(e + 1);
for (uint e = 0; e < axe1Size; e++)
axes1[e] = (int)INT_ARG(e + axe0Size + 2);
std::vector<int> permutAt, permutBt;
std::vector<Nd4jLong> shapeAt, shapeBt;
ShapeUtils::evalShapeForTensorDot(A, B, axes0, axes1, permutAt, permutBt, shapeAt, shapeBt);
// special case for scalar value
if (dLdC->isScalar()) {
dLdA->assign((*dLdC) * *B);
dLdB->assign((*dLdC) * *A);
return Status::OK();
}
std::vector<int> axesA = ShapeUtils::evalDimsToExclude(Arank, axes0);
std::vector<int> axesB = ShapeUtils::evalDimsToExclude(Brank, axes1);
// rank always have to be divided by 2
std::vector<int> axesAdLdC, axesBdLdC;
if (dLdCrank > 1) {
axesAdLdC.resize(dLdCrank / 2);
std::iota(axesAdLdC.begin(), axesAdLdC.end(), 0);
axesBdLdC = ShapeUtils::evalDimsToExclude(dLdCrank, axesAdLdC);
}
else {
axesAdLdC.push_back(0);
axesBdLdC.push_back(0);
}
// calculate dLdA
MmulHelper::tensorDot(dLdC, B, dLdA, axesBdLdC, axesB, permutAt);
// calculate dLdB
MmulHelper::tensorDot(A, dLdC, dLdB, axesA, axesAdLdC, permutBt);
return Status::OK();
}
////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(tensormmul_bp) {
auto aShapeInfo = inputShape->at(0);
auto bShapeInfo = inputShape->at(1);
auto dLShapeInfo = inputShape->at(2);
REQUIRE_TRUE((ArrayOptions::dataType(aShapeInfo) == ArrayOptions::dataType(bShapeInfo) &&
(ArrayOptions::dataType(dLShapeInfo) == ArrayOptions::dataType(aShapeInfo))), 0, "tensormmul_bp: A, B and dLdC data types must be the same");
Nd4jLong* dLdAShapeInfo = nullptr;
Nd4jLong* dLdBShapeInfo = nullptr;
COPY_SHAPE(aShapeInfo, dLdAShapeInfo);
COPY_SHAPE(bShapeInfo, dLdBShapeInfo);
return SHAPELIST(CONSTANT(dLdAShapeInfo), CONSTANT(dLdBShapeInfo));
}
////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(tensormmul_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF }) // maybe better ALL_FLOATS
->setAllowedInputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedInputTypes(2, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(0, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF })
->setAllowedOutputTypes(1, { DataType::FLOAT32, DataType::DOUBLE, DataType::HALF });
}
}
}
#endif

View File

@ -79,7 +79,7 @@ CUSTOM_OP_IMPL(conv1d, 2, 1, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput);
auto outputReshaped = output ->reshape(output->ordering(), reshapeForOutput, false);
auto weightsReshaped = weights->reshape(weights->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d conv2d;
@ -216,10 +216,10 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
}
auto inputReshaped = input ->reshape(input->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput);
auto gradIReshaped = gradI ->reshape(gradI->ordering(), reshapeForInput, false);
auto gradOReshaped = gradO ->reshape(gradO->ordering(), reshapeForGradO);
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto weightsReshaped = weights->reshape(weights->ordering(),{1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}, false);// [kW, iC, oC] -> [1, kW, iC, oC]
nd4j::ops::conv2d_bp conv2dBP;
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});

View File

@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
//----- calculation of gradO -----//
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -233,7 +233,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -243,7 +243,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
// ----- calculation of gradB ----- //
if(gradB) {
if(gradB->rankOf() == 2)
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
if(gradB != OUTPUT_VARIABLE(2))
delete gradB;

View File

@ -61,13 +61,13 @@ namespace nd4j {
}
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeAreaFunctor(block.launchContext(), &source, width, height, alignCorners, &target);
}
DECLARE_SHAPE_FN(resize_area) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -90,7 +90,7 @@ namespace nd4j {
}
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_area: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -62,13 +62,13 @@ namespace nd4j {
REQUIRE_TRUE(!halfPixelAlign || (halfPixelAlign && !alignCorners), 0, "resize_bicubic: `half_pixel_centers' should be false or true only when `align_corners' is false");
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeBicubicFunctorA(block.launchContext(), &source, width, height, alignCorners, halfPixelAlign, &target);
}
DECLARE_SHAPE_FN(resize_bicubic) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -82,7 +82,7 @@ namespace nd4j {
height = newImageSize->e<int>(1);
REQUIRE_TRUE(inRank == 4 || inRank == 3, 0, "resize_bicubic: Source tensor should have rank 4, but %i given.", inRank);
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -43,7 +43,7 @@ namespace nd4j {
REQUIRE_TRUE(inRank == output->rankOf(), 0, "resize_bilinear: Input and output ranks should be equals, but %i and %i occured.", inRank, output->rankOf());
auto source = inRank == 4?image->reshape(image->ordering(), {image->sizeAt(0), image->sizeAt(1), image->sizeAt(2), image->sizeAt(3)}):image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}):output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4?output->reshape(output->ordering(), {output->sizeAt(0), output->sizeAt(1), output->sizeAt(2), output->sizeAt(3)}, false) : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
if (block.width() > 1) {
auto newImageSize = INPUT_VARIABLE(1);
@ -71,7 +71,7 @@ namespace nd4j {
}
DECLARE_SHAPE_FN(resize_bilinear) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
Nd4jLong* outputShape;
@ -94,7 +94,7 @@ namespace nd4j {
width = INT_ARG(0);
height = INT_ARG(1);
}
ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
outputShape[0] = inRank;
if (inRank == 4) {

View File

@ -63,13 +63,13 @@ namespace nd4j {
REQUIRE_TRUE(((alignCorners && height > 2) || (height > 0)) && ((alignCorners && width > 1) || (width > 0)), 0, "resize_nearest_neighbor: Wrong input or output size to resize (width = %d, height = %d)", width, height);
auto source = inRank == 4?*image:image->reshape(image->ordering(), {1, image->sizeAt(0), image->sizeAt(1), image->sizeAt(2)});
auto target = inRank == 4?*output:output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)});
auto target = inRank == 4 ? *output : output->reshape(output->ordering(), {1, output->sizeAt(0), output->sizeAt(1), output->sizeAt(2)}, false);
return helpers::resizeNeighborFunctor(block.launchContext(), inRank==4?image:&source, width, height, alignCorners, halfPixelCenter, inRank == 4 ? output : &target);
}
DECLARE_SHAPE_FN(resize_nearest_neighbor) {
auto shapeList = SHAPELIST();
auto shapeList = SHAPELIST();
auto in = inputShape->at(0);
auto inRank = shape::rank(in);
Nd4jLong* outputShape;

View File

@ -36,14 +36,14 @@ namespace nd4j {
int _a = INT_ARG(e);
if (_a < 0)
_a += input->rankOf();
axis.emplace_back(_a);
}
else if (block.width() > 1) {
auto a = INPUT_VARIABLE(1);
for (Nd4jLong e = 0; e < a->lengthOf(); e++) {
int _a = a->e<int>(e);
if (_a < 0)
_a += input->rankOf();
@ -71,7 +71,7 @@ namespace nd4j {
}
if (block.isInplace()) {
output->reshapei(input->ordering(), shape);
output->reshapei(input->ordering(), shape, false);
} else {
auto tmp = input->reshape(input->ordering(), shape);
output->assign(tmp);
@ -106,20 +106,20 @@ namespace nd4j {
int _a = INT_ARG(e);
if (_a < 0)
_a += rank;
axis.emplace_back(_a);
}
else if (block.width() > 1) {
auto a = INPUT_VARIABLE(1);
for (int e = 0; e < a->lengthOf(); e++) {
int _a = a->e<int>(e);
if (_a < 0)
_a += rank;
axis.emplace_back(_a);
}
}
auto order = shape::order(in);

View File

@ -57,7 +57,8 @@ namespace nd4j {
* IArgs[1]... axes values for second array
*/
#if NOT_EXCLUDED(OP_tensormmul)
DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
DECLARE_CUSTOM_OP(tensormmul, 2, 1, false, 0, -1);
DECLARE_CUSTOM_OP(tensormmul_bp, 3, 2, false, 0, -1);
#endif
/**

View File

@ -432,7 +432,7 @@ namespace nd4j {
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
@ -505,7 +505,7 @@ namespace nd4j {
if(gradB) {
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}, false));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
if(gradBR != gradB)

View File

@ -30,7 +30,7 @@ namespace helpers {
void crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
auto _a = a->reshape(a->ordering(), {-1, 3});
auto _b = b->reshape(b->ordering(), {-1, 3});
auto _o = o->reshape(o->ordering(), {-1, 3});
auto _o = o->reshape(o->ordering(), {-1, 3}, false);
auto tadsA = _a.allTensorsAlongDimension({1});
auto tadsB = _b.allTensorsAlongDimension({1});

View File

@ -244,14 +244,14 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)});
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), output.sizeAt(3)}, false);
outputRearranged0.permutei({2, 3,0, 4,1, 5});
if(input.lengthOf() == output.lengthOf()) {
outputRearranged0.assign(input);
}
else {
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)});
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, output.sizeAt(3)}, false);
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatch_, (input, outputRearranged1, padBottom, padTop, padLeft, padRight), LIBND4J_TYPES);
if(output.getBuffer() != outputRearranged1.getBuffer())
@ -352,7 +352,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(int j = 1; j < rank; ++i, ++j)
temp[i] = output.sizeAt(j);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
//*** construct permuting std::vector for permutation of output array ***//
@ -382,7 +382,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(i = 1; i < rank; ++i)
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
BUILD_SINGLE_SELECTOR(input.dataType(), spaceToBatchND_, (input, padding, outputRearranged1, numOfSpatialDims), LIBND4J_TYPES);

View File

@ -59,7 +59,7 @@ void FORCEINLINE cross(nd4j::LaunchContext * context, NDArray *a, NDArray *b, ND
void FORCEINLINE _crossBatched(nd4j::LaunchContext * context, NDArray *a, NDArray *b, NDArray *o) {
auto a_ = a->reshape(a->ordering(), {-1, 3});
auto b_ = b->reshape(b->ordering(), {-1, 3});
auto o_ = o->reshape(o->ordering(), {-1, 3});
auto o_ = o->reshape(o->ordering(), {-1, 3}, false);
auto tadsA = a_.allTensorsAlongDimension({1});
auto tadsB = b_.allTensorsAlongDimension({1});

View File

@ -322,7 +322,7 @@ static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input,
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
NDArray columns(input->ordering(), {bS, iC, kH, kW, oH, oW}, input->dataType(), input->getContext());
NDArray outputReshaped = output->reshape(output->ordering(), outReShape);
NDArray outputReshaped = output->reshape(output->ordering(), outReShape, false);
helpers::im2col(*output->getContext(), *input, columns, kH, kW, sH, sW, pH, pW, dH, dW, NDArrayFactory::create(0.f, input->getContext())); // [bS, iC, iH, iW] is convoluted to [bS, iC, kH, kW, oH, oW]
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
@ -1228,7 +1228,7 @@ static void conv2dBP_(nd4j::graph::Context& block, const NDArray* input, const N
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot); // sum over bS, oH, oW
gradO->reduceAlongDimension(reduce::Sum, *gradBR, gradOaxesForDot, false); // sum over bS, oH, oW
if(gradBR != gradB)
delete gradBR;
}
@ -1310,7 +1310,7 @@ static void depthwiseConv2dBP_(const NDArray* input, const NDArray* weights, con
NDArray* gradBR = gradB;
if(gradB->rankOf() == 2)
gradBR = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}); // sum over bS, oH, oW
gradO->reduceAlongDimension(reduce::Sum, *gradBR, {0,indOoH,indOoH+1}, false); // sum over bS, oH, oW
if(gradBR != gradB)
delete gradBR;
}

View File

@ -313,7 +313,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
// [bS, iH, iW, iC] is rearranged/permuted to [bS*blockSize*blockSize, (iH + padBottom + padTop)/blockSize, (iW + padLeft + padRight)/blockSize, iC]
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)});
NDArray outputRearranged0 = output.reshape(output.ordering(), {blockSize, blockSize, input.sizeAt(0), output.sizeAt(1), output.sizeAt(2), input.sizeAt(3)}, false);
outputRearranged0.permutei({2, 3,0, 4,1, 5});
if(input.lengthOf() == output.lengthOf()) {
@ -322,7 +322,7 @@ void spaceToBatch(nd4j::LaunchContext* context, const NDArray& input, NDArray& o
}
else {
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)});
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), {input.sizeAt(0), output.sizeAt(1) * blockSize, output.sizeAt(2) * blockSize, input.sizeAt(3)}, false);
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
@ -439,7 +439,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(int j = 1; j < rank; ++i, ++j)
temp[i] = output.sizeAt(j);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp);
NDArray outputRearranged0 = output.reshape(output.ordering(), temp, false);
//*** construct permuting std::vector for permutation of output array ***//
@ -469,7 +469,7 @@ void spaceToBatchND(nd4j::LaunchContext* context, const NDArray& input, const ND
for(i = 1; i < rank; ++i)
temp[i] = (i <= numOfSpatialDims) ? output.sizeAt(i) * blockShape.e<Nd4jLong>(i - 1) : output.sizeAt(i);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp);
NDArray outputRearranged1 = outputRearranged0.reshape(output.ordering(), temp, false);
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

View File

@ -471,9 +471,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
if(cI)
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
if(hL)
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
if(cL)
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));
lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);

View File

@ -321,6 +321,280 @@ TEST_F(DeclarableOpsTests1, TestTensorDot4) {
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot5) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot6) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot7) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot8) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot9) {
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
// z.linspace(1);
// z.printShapeInfo();
// z.printIndexedBuffer();
// z.reshapei('c', {4,3});
// z.printShapeInfo();
// z.printIndexedBuffer();
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot10) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot11) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot12) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot13) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot14) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot15) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot16) {
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, TestTensorDot17) {
NDArray x('f', {16,16}, nd4j::DataType::FLOAT32);
NDArray y('f', {1000,16}, nd4j::DataType::FLOAT32);
NDArray z('c', {16,1000}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto status = op.execute({&x, &y}, {&z}, {}, {1,1, 1,1}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests1, DivergentCheck1) {
auto op = nd4j::ops::OpRegistrator::getInstance()->getOperation("switch");

View File

@ -708,30 +708,6 @@ TEST_F(DeclarableOpsTests12, multiUnique_2) {
ASSERT_TRUE(nd4j::ops::helpers::multiUnique(arrayList));
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, tensormmul_6) {
NDArray x('c', {1}, std::vector<double>{2}, nd4j::DataType::FLOAT32);
NDArray y('c', {2,1,2}, {1,2,3,4}, nd4j::DataType::FLOAT32);
NDArray exp('c', {2,2}, {2,4,6,8}, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0, 1,1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
// exp.printShapeInfo();
// result->printShapeInfo();
// result->printIndexedBuffer();
ASSERT_TRUE(exp.isSameShape(result));
ASSERT_TRUE(exp.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {

View File

@ -1560,3 +1560,447 @@ TEST_F(DeclarableOpsTests15, Pow_BP_Test11) {
delete resultsB;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) {
NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) {
NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 1 }, { 1 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(B.isSameShape(*dLdAbp));
ASSERT_TRUE(B.equalsTo(*dLdAbp));
ASSERT_TRUE(A.isSameShape(*dLdBbp));
ASSERT_TRUE(A.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) {
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) {
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) {
NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) {
NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, nd4j::DataType::FLOAT32);
auto dLdC = NDArrayFactory::create<float>(1.f);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(B.isSameShape(*dLdAbp));
ASSERT_TRUE(B.equalsTo(*dLdAbp));
ASSERT_TRUE(A.isSameShape(*dLdBbp));
ASSERT_TRUE(A.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) {
NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) {
NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, nd4j::DataType::FLOAT32);
NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, nd4j::DataType::FLOAT32);
NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dLdA.isSameShape(*dLdAbp));
ASSERT_TRUE(dLdA.equalsTo(*dLdAbp));
ASSERT_TRUE(dLdB.isSameShape(*dLdBbp));
ASSERT_TRUE(dLdB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) {
NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) {
NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) {
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) {
NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::FLOAT32);
NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::FLOAT32);
NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, nd4j::DataType::FLOAT32);
NDArray dB('c', { 2, 2, 3 }, { 25.86, 27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) {
NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, nd4j::DataType::DOUBLE);
NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4,
2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, nd4j::DataType::DOUBLE);
NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, nd4j::DataType::DOUBLE);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) {
NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, nd4j::DataType::DOUBLE);
NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2,
1.3, 1.4, 1.5, 1.6 }, nd4j::DataType::DOUBLE);
NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, nd4j::DataType::DOUBLE);
NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, nd4j::DataType::DOUBLE);
nd4j::ops::tensormmul_bp op_bp;
auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {});
ASSERT_EQ(ND4J_STATUS_OK, resultsBP->status());
auto* dLdAbp = resultsBP->at(0);
auto* dLdBbp = resultsBP->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdAbp));
ASSERT_TRUE(dA.equalsTo(*dLdAbp));
ASSERT_TRUE(dB.isSameShape(*dLdBbp));
ASSERT_TRUE(dB.equalsTo(*dLdBbp));
delete resultsBP;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) {
NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::FLOAT32);
NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, nd4j::DataType::FLOAT32);
NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, nd4j::DataType::FLOAT32);
NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. }, nd4j::DataType::FLOAT32);
nd4j::ops::tensormmul_bp op;
auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 });
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto* dLdA = results->at(0);
auto* dLdB = results->at(1);
ASSERT_TRUE(dA.isSameShape(*dLdA));
ASSERT_TRUE(dA.equalsTo(*dLdA));
ASSERT_TRUE(dB.isSameShape(*dLdB));
ASSERT_TRUE(dB.equalsTo(*dLdB));
delete results;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) {
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
nd4j::ops::tensormmul op;
nd4j::ops::tensormmul_bp op_bp;
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0});
ASSERT_TRUE(isGradCorrect);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) {
NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, nd4j::DataType::DOUBLE);
NDArray dLdC('c', { 2, 2 }, nd4j::DataType::DOUBLE);
const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 });
const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 });
nd4j::ops::tensormmul op;
nd4j::ops::tensormmul_bp op_bp;
const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 });
ASSERT_TRUE(isGradCorrect);
}

View File

@ -578,246 +578,6 @@ TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) {
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot5) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot6) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot7) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot8) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {2,4,2,4}, {30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,1,1,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot9) {
// NDArray z('f',{2,2,3}, nd4j::DataType::DOUBLE);
// z.linspace(1);
// z.printShapeInfo();
// z.printIndexedBuffer();
// z.reshapei('c', {4,3});
// z.printShapeInfo();
// z.printIndexedBuffer();
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,4,4,3}, {14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {1,0,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot10) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot11) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot12) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {2,4,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {4,4}, {272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,1, 2,0,2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot13) {
auto x = NDArrayFactory::create<double>('c', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {640,560,640, 576,624,576, 640,560,640});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot14) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('c', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {648,600,520, 648,536,648, 520,600,648});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, TestTensorDot15) {
auto x = NDArrayFactory::create<double>('f', {2,3,4}, {1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15});
auto y = NDArrayFactory::create<double>('f', {4,2,3}, {2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16});
auto expected = NDArrayFactory::create<double>('c', {3,3}, {624,624,624, 656,656,656, 624,624,624});
nd4j::ops::tensormmul op;
auto results = op.evaluate({&x, &y}, {}, {2,0,2, 2,1,0});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *result = results->at(0);
ASSERT_TRUE(expected.isSameShape(result));
ASSERT_TRUE(expected.equalsTo(result));
delete results;
}
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) {

View File

@ -2043,34 +2043,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test1) {
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN);
ASSERT_TRUE(isGradCorrect);
//************************************//
/* exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z));
delete result;
*/
//************************************//
/* exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z));
delete result;
*/
//************************************//
/* exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse});
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z));
delete result;
*/
}
////////////////////////////////////////////////////////////////////////////////
@ -2079,11 +2051,6 @@ TEST_F(DeclarableOpsTests9, cumprod_test2) {
auto inputC = NDArrayFactory::create<double>('c', {2, 2});
auto axis = NDArrayFactory::create<double>(1.);
// auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.});
// auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024});
// auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++
// auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1});
auto gradO = NDArrayFactory::create<double>('c', {2, 2});
int exclusive, reverse;

View File

@ -61,7 +61,7 @@ public class TestPCA extends BaseNd4jTest {
assertEquals("Reconstructed matrix is very different from the original.", 0.0, Diff.getDouble(i), 1.0);
}
}
@Test
public void testFactorSVDTransposed() {
int m = 4;