Shyrma temp (#131)

* - specifying template instantiation for certain types in float16 and bfloat16

Signed-off-by: Yurii <iuriish@yahoo.com>

* - polishing bfloat16 and float16 member functions template specialization

Signed-off-by: Yurii <iuriish@yahoo.com>

* - rewrite and overload array +-*/ scalar and scalar +-*/ array operators in the NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>

* - make corrections related to rvalue/lvalue conversions

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide move semantics in NDArray operators array +-*/ array

Signed-off-by: Yurii <iuriish@yahoo.com>
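The overload pattern behind the two commits above, shown with a minimal self-contained toy type rather than the actual NDArray code (names and internals here are illustrative only):

    #include <utility>
    #include <vector>

    // Toy stand-in for an array type that owns a buffer.
    struct Arr {
        std::vector<float> data;
    };

    // lvalue operand: the result needs its own freshly allocated buffer.
    Arr operator+(const Arr& a, float s) {
        Arr r{a.data};                 // copy the buffer
        for (auto& v : r.data) v += s;
        return r;
    }

    // rvalue operand: reuse the temporary's buffer instead of allocating.
    Arr operator+(Arr&& a, float s) {
        for (auto& v : a.data) v += s; // write in place
        return std::move(a);           // hand the buffer on to the result
    }

With both overloads present, a chained expression such as a + 1.f + 2.f allocates only once, because the intermediate temporary is consumed by the rvalue overload.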

* float16/bfloat16 tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* - make float16 and bfloat16 compile successfully on CUDA

Signed-off-by: Yurii <iuriish@yahoo.com>

* - do not reuse the resources of view-like arrays when move semantics is applied

Signed-off-by: Yurii <iuriish@yahoo.com>

* - get rid of pointers in signatures of NDArray methods, part 1

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correction of signature of NDArray::dup method

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correction of signature of NDArray::reduceAlongDimension method

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::tensorsAlongDimension and diagonal methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::allTensorsAlongDimension

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::reduceAlongDimension 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyTransform 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyPairwiseTransform 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyBroadcast 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyTrueBroadcast 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyScalar and applyScalarArr

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::lambda methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::reduce3 methods 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of the following NDArray methods: add/sub/mul/div row/column and fillAsTriangular

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::tileToShape methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::isShapeSameStrict method

Signed-off-by: Yurii <iuriish@yahoo.com>

* minor corrections in tests

Signed-off-by: Yurii <iuriish@yahoo.com>

* - replace reduce op in batchnorm mkldnn

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add explicit template instantiations for operator+(NDArray&&, const scalar)

Signed-off-by: Yurii <iuriish@yahoo.com>
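Explicit instantiation of such an operator template follows the usual C++ pattern; a sketch of the idea (the declaration below is simplified, and the exact scalar types emitted by this commit are not visible here):

    // Header: templated operator declared once.
    template <typename T>
    ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar);

    // One translation unit: explicit instantiations for the supported scalar types,
    // so the definition does not have to live in the header.
    template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar);
    template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar);
    template ND4J_EXPORT NDArray operator+(NDArray&& arr, const int& scalar);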

* - corrections of casts in float16/bfloat16

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide move semantics in the following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute

Signed-off-by: Yurii <iuriish@yahoo.com>

* - get rid of the duplicate of input array A in the svd CUDA op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - work around a known bug in the svd CUDA API

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add a temporary global memory buffer in svd CUDA when calcUV = false and m != n

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove test with bfloat16 type for betainc

Signed-off-by: Yurii <iuriish@yahoo.com>

* - resolve conflicts after master has been merged in

Signed-off-by: Yurii <iuriish@yahoo.com>

* - changed type of affected input array in fused_batch_norm

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add several explicit type castings

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add ND4J_EXPORT to operators

Signed-off-by: Yurii <iuriish@yahoo.com>
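ND4J_EXPORT is the library's symbol-visibility macro; conceptually it expands to the usual export attributes (this is a sketch of the common pattern, not the repository's exact definition):

    #if defined(_MSC_VER)
        #define ND4J_EXPORT __declspec(dllexport)
    #else
        #define ND4J_EXPORT __attribute__((visibility("default")))
    #endif

    // Applied to the free arithmetic operators so they are exported from the shared library:
    template <typename T>
    ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar);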

* - add explicit template types in instantiations of the templated arithmetic operators of the NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>

* - one more test fix

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
Branch: master
Yurii Shyrma, 2019-12-20 21:35:39 +02:00, committed by raver119
parent 3e0afadea1
commit 5d9b2a16e5
237 changed files with 5235 additions and 6513 deletions

(Two file diffs are suppressed because they are too large to display.)

@@ -133,7 +133,7 @@ namespace graph {
 if (variableSpace->hasVariable(v->getName())) {
 // symbolic feeder
 auto array = variableSpace->getVariable(v->getName())->getNDArray();
-auto vr = array->dup();
+auto vr = new NDArray(array->dup());
 // deletables.push_back(vr);
 v->setNDArray(vr);
 } else {
@@ -145,7 +145,7 @@ namespace graph {
 // if we're not using symbolic lookup - we'll use sequential approach then
 auto p = node->input()->at(cnt);
 auto array = variableSpace->getVariable(p)->getNDArray();
-auto vr = array->dup();
+auto vr = new NDArray(array->dup());
 //deletables.push_back(vr);
 v->setNDArray(vr);
 }
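This hunk reflects the changed NDArray::dup signature from the commit list above: dup() now returns an NDArray by value rather than a heap-allocated pointer, so call sites that still need a raw pointer wrap the result, roughly:

    NDArray copy = array->dup();          // value return, no manual delete
    auto vr = new NDArray(array->dup());  // heap copy where a pointer is still required;
                                          // assumes the temporary is moved into the new object
    v->setNDArray(vr);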


@@ -71,44 +71,41 @@ void NDArray::makeBothBuffersActual() const { }
 ////////////////////////////////////////////////////////////////////////
 template <typename T>
-void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) {
+void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) {
 if (isS())
 throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!");
-if(target == nullptr)
-target = this;
-if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1)))
+if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1)))
 throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !");
 if (direction == 'u')
-lower = -target->sizeAt(-2);
+lower = -target.sizeAt(-2);
 else if (direction == 'l')
-upper = target->sizeAt(-1);
+upper = target.sizeAt(-1);
 const T value = static_cast<T>(val);
 const auto x = reinterpret_cast<const T*>(getBuffer());
-auto z = reinterpret_cast<T*>(target->getBuffer());
+auto z = reinterpret_cast<T*>(target.getBuffer());
 const int xRank = rankOf();
-const int zRank = target->rankOf();
-const auto zLen = target->lengthOf();
-const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo());
+const int zRank = target.rankOf();
+const auto zLen = target.lengthOf();
+const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo());
 auto func = PRAGMA_THREADS_FOR {
 Nd4jLong coords[MAX_RANK];
 for (auto i = start; i < stop; i += increment) {
-shape::index2coords(i, target->getShapeInfo(), coords);
-const auto zOffset = shape::getOffset(target->getShapeInfo(), coords);
+shape::index2coords(i, target.getShapeInfo(), coords);
+const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
 // if( (row + upper < col) || (row + lower > col) )
 if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
 z[zOffset] = value;
-else if (this != target) { // when this and target are different arrays
+else if (this != &target) { // when this and target are different arrays
 if (xRank != zRank)
 coords[0] = coords[1];
@@ -120,7 +117,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
 samediff::Threads::parallel_for(func, 0, zLen);
 }
-BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
 ////////////////////////////////////////////////////////////////////////
 void NDArray::setIdentity() {
@@ -405,11 +402,11 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
 //////////////////////////////////////////////////////////////////////////
 // create new array by repeating it the number of times given by repeats
-NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
-auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
-BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, *output, repeats, axis), LIBND4J_TYPES);
+NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
+NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
+BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES);
 return output;
 }
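Same pattern for NDArray::repeat, which now returns its result by value; given some existing array x, a usage sketch under the new signature:

    NDArray r = x.repeat(0, {2});   // result by value; previously a heap pointer that had to be deleted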


@@ -2,35 +2,24 @@
 template<typename T>
-void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
-if (target == nullptr)
-target = this;
-if (second == nullptr) {
-nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
-throw std::runtime_error("second is null");
-}
-if (third == nullptr) {
-nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
-throw std::runtime_error("third is null");
-}
+void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function<T(T, T, T)>& func, NDArray& target) {
 if(dataType() != DataTypeUtils::fromT<T>())
 throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
-if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
+if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType())
 throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
-if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
+if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
 nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
 throw std::runtime_error("Shapes mismach");
 }
 auto f = this->bufferAsT<T>();
-auto s = second->bufferAsT<T>();
-auto t = third->bufferAsT<T>();
-auto z = target->bufferAsT<T>();
-if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
+auto s = second.bufferAsT<T>();
+auto t = third.bufferAsT<T>();
+auto z = target.bufferAsT<T>();
+if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment)
@@ -44,8 +33,8 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto tOffset = this->getOffset(e);
-auto uOffset = second->getOffset(e);
-auto vOffset = third->getOffset(e);
+auto uOffset = second.getOffset(e);
+auto vOffset = third.getOffset(e);
 f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
 }
@@ -57,9 +46,9 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto tOffset = this->getOffset(e);
-auto uOffset = second->getOffset(e);
-auto vOffset = third->getOffset(e);
-auto zOffset = target->getOffset(e);
+auto uOffset = second.getOffset(e);
+auto vOffset = third.getOffset(e);
+auto zOffset = target.getOffset(e);
 z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
 }
@@ -69,46 +58,39 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
 }
 }
 }
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
-template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<double (double, double, double)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<float (float, float, float)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<float16 (float16, float16, float16)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int (int, int, int)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray& target);
+template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<bool (bool, bool, bool)>& func, NDArray& target);
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
-if (target == nullptr)
-target = this;
-if (other == nullptr) {
-nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
-throw std::runtime_error("Other is null");
-}
+void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T, T)>& func, NDArray& target) {
 if(dataType() != DataTypeUtils::fromT<T>())
 throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
-if(dataType() != other->dataType() || dataType() != target->dataType())
+if(dataType() != other.dataType() || dataType() != target.dataType())
 throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
-if (this->lengthOf() != other->lengthOf()) {
+if (this->lengthOf() != other.lengthOf()) {
 nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
 throw std::runtime_error("Shapes mismach");
 }
 auto f = this->bufferAsT<T>();
-auto s = other->bufferAsT<T>();
-auto z = target->bufferAsT<T>();
-if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
+auto s = other.bufferAsT<T>();
+auto z = target.bufferAsT<T>();
+if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment)
@@ -122,7 +104,7 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto yOffset = other->getOffset(e);
+auto yOffset = other.getOffset(e);
 f[xOffset] = func(f[xOffset], s[yOffset]);
 }
@@ -134,8 +116,8 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto yOffset = other->getOffset(e);
-auto zOffset = target->getOffset(e);
+auto yOffset = other.getOffset(e);
+auto zOffset = target.getOffset(e);
 z[zOffset] = func(f[xOffset], s[yOffset]);
 }
@@ -145,35 +127,33 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
 }
 }
 }
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
-template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<double (double, double)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<float (float, float)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<float16 (float16, float16)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int (int, int)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray& target);
+template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<bool (bool, bool)>& func, NDArray& target);
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
-if (target == nullptr)
-target = this;
+void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
 if(dataType() != DataTypeUtils::fromT<T>())
 throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
-if(dataType() != target->dataType())
+if(dataType() != target.dataType())
 throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
 auto f = this->bufferAsT<T>();
-auto z = target->bufferAsT<T>();
-if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
+auto z = target.bufferAsT<T>();
+if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment)
@@ -198,7 +178,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto zOffset = target->getOffset(e);
+auto zOffset = target.getOffset(e);
 z[zOffset] = func(f[xOffset]);
 }
@@ -208,35 +188,33 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
 }
 }
 }
-template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
-template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
+template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray& target);
+template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray& target);
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
-if (target == nullptr)
-target = this;
+void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray& target) {
 if(dataType() != DataTypeUtils::fromT<T>())
 throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
-if(dataType() != target->dataType())
+if(dataType() != target.dataType())
 throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
 auto f = this->bufferAsT<T>();
-auto z = target->bufferAsT<T>();
-if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
+auto z = target.bufferAsT<T>();
+if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment)
@@ -261,7 +239,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto zOffset = target->getOffset(e);
+auto zOffset = target.getOffset(e);
 z[zOffset] = func(e, f[xOffset]);
 }
@@ -271,44 +249,38 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
 }
 }
 }
-template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
-template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
+template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray& target);
+template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray& target);
 //////////////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
-if (target == nullptr)
-target = this;
-if (other == nullptr) {
-nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
-throw std::runtime_error("Other is null");
-}
+void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(Nd4jLong, T, T)>& func, NDArray& target) {
 if(dataType() != DataTypeUtils::fromT<T>())
 throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
-if(dataType() != target->dataType())
+if(dataType() != target.dataType())
 throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
-if (this->lengthOf() != other->lengthOf()) {
+if (this->lengthOf() != other.lengthOf()) {
 nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
 throw std::runtime_error("Shapes mismach");
 }
 auto f = this->bufferAsT<T>();
-auto s = other->bufferAsT<T>();
-auto z = target->bufferAsT<T>();
-if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
+auto s = other.bufferAsT<T>();
+auto z = target.bufferAsT<T>();
+if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment)
@@ -322,7 +294,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto yOffset = other->getOffset(e);
+auto yOffset = other.getOffset(e);
 f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
 }
@@ -334,8 +306,8 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
 auto loop = PRAGMA_THREADS_FOR {
 for (auto e = start; e < stop; e += increment) {
 auto xOffset = this->getOffset(e);
-auto yOffset = other->getOffset(e);
-auto zOffset = target->getOffset(e);
+auto yOffset = other.getOffset(e);
+auto zOffset = target.getOffset(e);
 z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
 }
@@ -345,16 +317,16 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
 }
 }
 }
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
-template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<double (Nd4jLong, double, double)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<float (Nd4jLong, float, float)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int (Nd4jLong, int, int)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray& target);
+template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray& target);
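Under the reference-based CPU signature shown in these hunks, a pairwise lambda call reads roughly as follows (array construction omitted; the explicit template argument lets a plain lambda convert to the expected std::function):

    // x, y and z are float NDArrays of the same shape.
    x.applyPairwiseLambda<float>(y, [](float a, float b) { return a * b + 1.f; }, z);
    // In-place variant: pass the array itself as the target.
    x.applyPairwiseLambda<float>(y, [](float a, float b) { return a + b; }, x);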


@@ -2717,25 +2717,25 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
 switch (opCode) {
 case 0:
-inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr);
 break;
 case 1:
-inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr);
 break;
 case 2:
-inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr);
 break;
 case 3:
-inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr);
 break;
 case 4:
-inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr);
 break;
 case 5:
-inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr);
 break;
 case 6:
-inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr);
+inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr);
 break;
 default:
 continue;


@@ -122,35 +122,32 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha
 ///////////////////////////////////////////////////////////////////
 template<typename T>
-void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) {
+void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) {
 if (isS())
 throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!");
-if(target == nullptr)
-target = this;
-if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1)))
+if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1)))
 throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !");
 if (direction == 'u')
-lower = -target->sizeAt(-2);
+lower = -target.sizeAt(-2);
 else if (direction == 'l')
-upper = target->sizeAt(-1);
+upper = target.sizeAt(-1);
 const int threadsPerBlock = MAX_NUM_THREADS / 4;
-const int blocksPerGrid = (target->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-const int sharedMem = threadsPerBlock * sizeof(decltype(*target->getShapeInfo())) * target->rankOf() + 128;
+const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+const int sharedMem = threadsPerBlock * sizeof(decltype(*target.getShapeInfo())) * target.rankOf() + 128;
 PointersManager manager(getContext(), "NDArray::fillAsTriangular");
-NDArray::prepareSpecialUse({target}, {this});
-fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target->getPlatformBuffer(), target->getPlatformShapeInfo(), static_cast<T>(val), lower, upper);
-NDArray::registerSpecialUse({target}, {this});
+NDArray::prepareSpecialUse({&target}, {this});
+fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast<T>(val), lower, upper);
+NDArray::registerSpecialUse({&target}, {this});
 manager.synchronize();
 }
-BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
 ////////////////////////////////////////////////////////////////////////
 template<typename T>
@@ -457,21 +454,21 @@ BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid
 //////////////////////////////////////////////////////////////////////////
 // create new array by repeating it the number of times given by repeats
-NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
-auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
+NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
+NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
 const int threadsPerBlock = MAX_NUM_THREADS / 2;
-const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
-const int sharedMem = output->rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
 PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
 const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
-prepareSpecialUse({output}, {this});
-BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
-prepareSpecialUse({output}, {this});
+prepareSpecialUse({&output}, {this});
+BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
+prepareSpecialUse({&output}, {this});
 manager.synchronize();


@ -247,73 +247,73 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename Lambda> template<typename Lambda>
void NDArray::applyLambda(Lambda func, NDArray* target) { void NDArray::applyLambda(Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType()) if (dtype != target.dataType())
throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same");
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({result}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this}); registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename Lambda> template<typename Lambda>
void NDArray::applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target) { void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != other->dataType()) if (dtype != target.dataType() || dtype != other.dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same");
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({result}, {this, other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, other}); registerSpecialUse({&target}, {this, &other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyIndexedLambda(Lambda func, NDArray* target) { void NDArray::applyIndexedLambda(Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType()) if (dtype != target.dataType())
throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same");
prepareSpecialUse({result}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this}); registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target) { void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != other->dataType()) if (dtype != target.dataType() || dtype != other.dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({result}, {this, other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, other}); registerSpecialUse({&target}, {this, &other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target) { void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != second->dataType() || dtype != third->dataType()) if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType())
throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({result}, {this, second, third}); prepareSpecialUse({&target}, {this, &second, &third});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second->specialBuffer(), second->specialShapeInfo(), third->specialBuffer(), third->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, second, third}); registerSpecialUse({&target}, {this, &second, &third});
} }
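
The lambda helpers above now take both their operands and the write target by reference, so the old convention of passing a nullptr target to mean "apply in place" is gone: in-place application passes the array itself as the target. A minimal usage sketch on the CPU backend, assuming the code base's <NDArray.h>/<NDArrayFactory.h> includes (the CUDA backend shown in these hunks would additionally require a device-compatible functor); the values are illustrative.

#include <NDArray.h>
#include <NDArrayFactory.h>
using namespace nd4j;

void pairwiseLambdaSketch() {
    auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto y = NDArrayFactory::create<float>('c', {2, 2}, {5.f, 6.f, 7.f, 8.f});
    auto z = NDArrayFactory::create<float>('c', {2, 2});

    // out-of-place: the target is a reference, never a pointer that may be nullptr
    x.applyPairwiseLambda(y, [](float a, float b) { return a + b; }, z);

    // in-place: the array itself is passed as its own target
    x.applyPairwiseLambda(y, [](float a, float b) { return a * b; }, x);
}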


@ -91,6 +91,10 @@ namespace nd4j {
template <typename T> template <typename T>
FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo);
template<typename T>
// struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };
}; };
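
The widened scalarTypesForNDarray whitelist above enumerates the C++ scalar types the NDArray scalar-operator templates are meant to accept. A stand-alone illustration of how such a trait is typically consumed as an SFINAE gate; FakeArray, the shortened list, and the operator below are stand-ins for this sketch, not declarations from the code base.

#include <cstdint>
#include <type_traits>

// Shortened stand-in mirroring the shape of scalarTypesForNDarray.
template <typename T>
struct scalarWhitelistSketch {
    static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value ||
                              std::is_same<int, T>::value    || std::is_same<long long, T>::value ||
                              std::is_same<bool, T>::value   || std::is_same<int8_t, T>::value;
};

struct FakeArray { double v; };

// The overload participates in resolution only for whitelisted scalar types.
template <typename T, typename std::enable_if<scalarWhitelistSketch<T>::value, int>::type = 0>
FakeArray operator+(const FakeArray& arr, const T scalar) {
    return FakeArray{ arr.v + static_cast<double>(scalar) };
}

// FakeArray{1.0} + 2      -> compiles, int is in the list
// FakeArray{1.0} + "str"  -> rejected at overload resolution, const char* is not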


@ -44,7 +44,7 @@ namespace nd4j {
} }
NDArray* NDArrayList::read(int idx) { NDArray* NDArrayList::read(int idx) {
return readRaw(idx)->dup(); return new NDArray(readRaw(idx)->dup());
} }
nd4j::DataType NDArrayList::dataType() { nd4j::DataType NDArrayList::dataType() {
@ -136,11 +136,10 @@ namespace nd4j {
std::vector<int> args({axis}); std::vector<int> args({axis});
auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args);
auto result = array->allTensorsAlongDimension(newAxis); auto result = array->allTensorsAlongDimension(newAxis);
for (int e = 0; e < result->size(); e++) { for (int e = 0; e < result.size(); e++) {
auto chunk = result->at(e);//->dup(array->ordering()); auto chunk = result.at(e);//->dup(array->ordering());
write(e, chunk->dup(array->ordering())); write(e, new NDArray(chunk->dup(array->ordering())));
} }
delete result;
} }
NDArray* NDArrayList::stack() { NDArray* NDArrayList::stack() {
@ -161,7 +160,7 @@ namespace nd4j {
auto result = op.execute(inputs, {}, {}, {}); auto result = op.execute(inputs, {}, {}, {});
auto array = result->at(0)->dup(); auto array = new NDArray(result->at(0)->dup());
delete result; delete result;
@ -214,13 +213,11 @@ namespace nd4j {
auto tads = array->allTensorsAlongDimension(axis); auto tads = array->allTensorsAlongDimension(axis);
int indicesSize = indices.size(); int indicesSize = indices.size();
if (tads->size() != indicesSize) if (tads.size() != indicesSize)
throw std::runtime_error("Number of TADs should match number of indices"); throw std::runtime_error("Number of TADs should match number of indices");
for (int e = 0; e < indicesSize; e++) for (int e = 0; e < indicesSize; e++)
tads->at(e)->assign(_chunks[indices[e]]); tads.at(e)->assign(_chunks[indices[e]]);
delete tads;
return array; return array;
} }
@ -234,7 +231,7 @@ namespace nd4j {
list->_elements.store(_elements.load()); list->_elements.store(_elements.load());
for (auto const& v : _chunks) { for (auto const& v : _chunks) {
list->_chunks[v.first] = v.second->dup(); list->_chunks[v.first] = new NDArray(v.second->dup());
} }
return list; return list;
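
The NDArrayList changes above all stem from dup() and allTensorsAlongDimension() returning values rather than owning pointers: the explicit deletes disappear, and results are wrapped in new NDArray(...) only where an owning pointer is still required. A minimal sketch of the pattern, assuming the code base's <NDArray.h>; the dimension and fill value are illustrative.

#include <NDArray.h>
using namespace nd4j;

void ownershipSketch(NDArray* source) {
    // dup() returns by value now; wrap it only where an owning pointer is still needed
    NDArray  copyByValue = source->dup();
    NDArray* heapCopy    = new NDArray(source->dup());

    // the tensor set is a value type as well, so there is nothing to delete afterwards
    auto tads = source->allTensorsAlongDimension({0});
    for (int e = 0; e < tads.size(); e++)
        tads.at(e)->assign(0.);

    delete heapCopy;   // only the explicit heap copy needs manual cleanup
}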


@ -48,7 +48,7 @@ namespace nd4j {
} else { } else {
// FIXME: in some cases it's possible to have no NDArray // FIXME: in some cases it's possible to have no NDArray
if (inputVar->hasNDArray()) if (inputVar->hasNDArray())
innerVar->setNDArray(inputVar->getNDArray()->dup()); innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup()));
} }
} }


@ -56,7 +56,7 @@ namespace nd4j {
} else { } else {
// FIXME: in some cases it's possible to have no NDArray // FIXME: in some cases it's possible to have no NDArray
if (inputVar->hasNDArray()) if (inputVar->hasNDArray())
innerVar->setNDArray(inputVar->getNDArray()->dup()); innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup()));
} }
} }


@ -40,7 +40,7 @@ namespace nd4j {
result->setIndex(this->_index); result->setIndex(this->_index);
if (this->_ndarray != nullptr) if (this->_ndarray != nullptr)
result->setNDArray(this->_ndarray->template asT<N>()); result->setNDArray(new NDArray(this->_ndarray->template asT<N>()));
// FIXME: add support for ArrayList // FIXME: add support for ArrayList
if (this->_list != nullptr) { if (this->_list != nullptr) {
@ -61,7 +61,7 @@ namespace nd4j {
result->_index = this->_index; result->_index = this->_index;
if (this->_ndarray != nullptr) if (this->_ndarray != nullptr)
result->_ndarray = this->_ndarray->dup(this->_ndarray->ordering()); result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
if (this->_list != nullptr) if (this->_list != nullptr)
result->_list = this->_list->clone(); result->_list = this->_list->clone();


@ -93,7 +93,7 @@ namespace nd4j {
} }
OpBenchmark* clone() override { OpBenchmark* clone() override {
return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : _x->dup() , _y == nullptr ? _y : _y->dup(), _z == nullptr ? _z : _z->dup()); return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup()));
} }
}; };
} }


@ -230,17 +230,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
bool cNcont = N == 1 || C->strideAt(1) == 1; bool cNcont = N == 1 || C->strideAt(1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
toDelete.push_back(pA); toDelete.push_back(pA);
aMcont = true; aMcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('f'); pB = new NDArray(B->dup('f'));
toDelete.push_back(pB); toDelete.push_back(pB);
bKcont = true; bKcont = true;
} }
if(!cMcont && !cNcont) { if(!cMcont && !cNcont) {
pC = C->dup('f'); pC = new NDArray(C->dup('f'));
toDelete.push_back(pC); toDelete.push_back(pC);
cMcont = true; cMcont = true;
} }
@ -332,7 +332,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
bool aNcont = N == 1 || A->strideAt(1) == 1; bool aNcont = N == 1 || A->strideAt(1) == 1;
if(!aMcont && !aNcont) { if(!aMcont && !aNcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
aMcont = true; aMcont = true;
} }
const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor;


@ -60,11 +60,10 @@ NDArray Householder<T>::evalHHmatrix(const NDArray& x) {
w.p(Nd4jLong(0), 1.f); w.p(Nd4jLong(0), 1.f);
wT.assign(&w); wT.assign(&w);
auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext());
identity.setIdentity(); // identity matrix identity.setIdentity(); // identity matrix
return identity - mmul(w, wT) * coeff; return identity - mmul(w, wT) * coeff;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -95,9 +94,9 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
coeff = -u0 / normX; coeff = -u0 / normX;
if(x.isRowVector()) if(x.isRowVector())
tail.assign(x({0,0, 1,-1}) / u0); tail.assign(static_cast<const NDArray&>(x({0,0, 1,-1})) / u0);
else else
tail.assign(x({1,-1, 0,0,}) / u0); tail.assign(static_cast<const NDArray&>(x({1,-1, 0,0,})) / u0);
} }
} }
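
The static_cast<const NDArray&> insertions above force the const-lvalue overloads of the arithmetic operators onto sub-array views, presumably so a view produced by operator()(...) is never consumed as a movable temporary by the new rvalue overloads. A short sketch of the same pattern, assuming the code base's <NDArray.h>; the interval indices are taken from the hunk above.

#include <NDArray.h>
using namespace nd4j;

void viewArithmeticSketch(NDArray& x, NDArray& tail, float u0) {
    // x({0,0, 1,-1}) is a view into x; the cast selects the lvalue operator/ overload
    tail.assign(static_cast<const NDArray&>(x({0,0, 1,-1})) / u0);
}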


@ -269,7 +269,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
HHcolPivQR qr(matrix / scale); HHcolPivQR qr(matrix / scale);
_m.assign(qr._qr({0,_cols, 0,_cols})); _m.assign(qr._qr({0,_cols, 0,_cols}));
_m.fillAsTriangular<T>(0., 0, 0, 'l'); _m.fillAsTriangular<T>(0., 0, 0, _m, 'l');
HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); HHsequence hhSeg(qr._qr, qr._coeffs, 'u');
@ -288,7 +288,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
auto matrixT = matrix.transpose(); auto matrixT = matrix.transpose();
HHcolPivQR qr(matrixT / scale); HHcolPivQR qr(matrixT / scale);
_m.assign(qr._qr({0,_rows, 0,_rows})); _m.assign(qr._qr({0,_rows, 0,_rows}));
_m.fillAsTriangular<T>(0., 0, 0, 'l'); _m.fillAsTriangular<T>(0., 0, 0, _m, 'l');
_m.transposei(); _m.transposei();
HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here !
@ -305,7 +305,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
} }
else { else {
_m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); _m.assign(static_cast<const NDArray&>(matrix({0,_diagSize, 0,_diagSize})) / scale);
if(_calcU) if(_calcU)
_u.setIdentity(); _u.setIdentity();
@ -366,7 +366,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
_s.p(i, math::nd4j_abs<T>(_m.e<T>(i,i))); _s.p(i, math::nd4j_abs<T>(_m.e<T>(i,i)));
if(_calcU && _m.e<T>(i,i) < (T)0.) { if(_calcU && _m.e<T>(i,i) < (T)0.) {
auto temp = _u({0,0, i,i+1}, true); auto temp = _u({0,0, i,i+1}, true);
temp.applyTransform(transform::Neg, &temp, nullptr); temp.applyTransform(transform::Neg, temp, nullptr);
} }
} }


@ -223,26 +223,26 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
const T almostZero = DataTypeUtils::min<T>(); const T almostZero = DataTypeUtils::min<T>();
T maxElem; T maxElem;
if(len == 1) if(len == 1)
maxElem = math::nd4j_abs<T>(diagInterval->template e<T>(0)); maxElem = math::nd4j_abs<T>(diagInterval.template e<T>(0));
else else
maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0); maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0);
T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0);
T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem); T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem);
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem); T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
if(diagInterval->template e<T>(0) < epsBig) if(diagInterval.template e<T>(0) < epsBig)
diagInterval->p(Nd4jLong(0), epsBig); diagInterval.p(Nd4jLong(0), epsBig);
for(int i=1; i < len; ++i) for(int i=1; i < len; ++i)
if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps) if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps)
colVec0->p(i, 0.f); colVec0->p(i, 0.f);
for(int i=1; i < len; i++) for(int i=1; i < len; i++)
if(diagInterval->template e<T>(i) < epsBig) { if(diagInterval.template e<T>(i) < epsBig) {
deflation1(col1, shift, i, len); deflation1(col1, shift, i, len);
for(int i = 0; i < len; ++i) for(int i = 0; i < len; ++i)
diagInterval->p(i, _m.e<T>(col1+shift+i,col1+shift+i)); diagInterval.p(i, _m.e<T>(col1+shift+i,col1+shift+i));
} }
{ {
@ -261,7 +261,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
int p = 1; int p = 1;
for(int i=1; i<len; ++i) for(int i=1; i<len; ++i)
if(math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero) if(math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero)
permut[p++] = i; permut[p++] = i;
int k = 1, m = ind+1; int k = 1, m = ind+1;
@ -271,7 +271,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
permut[p] = m++; permut[p] = m++;
else if(m >= len) else if(m >= len)
permut[p] = k++; permut[p] = k++;
else if(diagInterval->template e<T>(k) < diagInterval->template e<T>(m)) else if(diagInterval.template e<T>(k) < diagInterval.template e<T>(m))
permut[p] = m++; permut[p] = m++;
else else
permut[p] = k++; permut[p] = k++;
@ -281,7 +281,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
if(totDefl) { if(totDefl) {
for(int i=1; i<len; ++i) { for(int i=1; i<len; ++i) {
int ki = permut[i]; int ki = permut[i];
if(math::nd4j_abs<T>(diagInterval->template e<T>(ki)) < almostZero || diagInterval->template e<T>(0) < diagInterval->template e<T>(ki)) if(math::nd4j_abs<T>(diagInterval.template e<T>(ki)) < almostZero || diagInterval.template e<T>(0) < diagInterval.template e<T>(ki))
permut[i-1] = permut[i]; permut[i-1] = permut[i];
else { else {
permut[i-1] = 0; permut[i-1] = 0;
@ -303,10 +303,10 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
const int ki = permut[len - (totDefl ? i+1 : i)]; const int ki = permut[len - (totDefl ? i+1 : i)];
const int jac = tCol[ki]; const int jac = tCol[ki];
T _e0 = diagInterval->template e<T>(jac); T _e0 = diagInterval.template e<T>(jac);
//math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac)); //math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac));
diagInterval->p(jac, diagInterval->template e<T>(i)); diagInterval.p(jac, diagInterval.template e<T>(i));
diagInterval->p(i, _e0); diagInterval.p(i, _e0);
if(i!=0 && jac!=0) { if(i!=0 && jac!=0) {
_e0 = colVec0->template e<T>(jac); _e0 = colVec0->template e<T>(jac);
@ -315,9 +315,8 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
colVec0->p(i, _e0); colVec0->p(i, _e0);
} }
NDArray* temp1 = nullptr, *temp2 = nullptr;
if (_calcU) { if (_calcU) {
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true);
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true);
auto temp3 = temp1; auto temp3 = temp1;
temp1.assign(temp2); temp1.assign(temp2);
@ -352,12 +351,12 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
{ {
int i = len-1; int i = len-1;
while(i > 0 && (math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero)) while(i > 0 && (math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero))
--i; --i;
for(; i > 1; --i) { for(; i > 1; --i) {
if( (diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) { if( (diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
if (math::nd4j_abs<T>(diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) >= epsBig) if (math::nd4j_abs<T>(diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) >= epsBig)
throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !");
deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len);
} }
@ -365,7 +364,6 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
} }
delete colVec0; delete colVec0;
delete diagInterval;
} }
@ -609,9 +607,7 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
const T almostZero = DataTypeUtils::min<T>(); const T almostZero = DataTypeUtils::min<T>();
auto col0 = _m({col1, col1+size, col1, col1+1}, true); auto col0 = _m({col1, col1+size, col1, col1+1}, true);
auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); auto diag = static_cast<const NDArray&>(_m({col1, col1+size, col1, col1+size}, true).diagonal('c'));
auto diag = *diagP;
delete diagP;
diag.p(Nd4jLong(0), T(0)); diag.p(Nd4jLong(0), T(0));
singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext()); singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
@ -730,8 +726,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true);
temp.assign(0.); temp.assign(0.);
auto diag = _m.diagonal('c'); auto diag = _m.diagonal('c');
(*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true));
delete diag;
return; return;
} }
@ -762,11 +757,6 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
f.assign(_u({0,1, col1+k+1, col1+n}, true)); f.assign(_u({0,1, col1+k+1, col1+n}, true));
} }
// UofSVD.printIndexedBuffer();
// VofSVD.printIndexedBuffer();
// singVals.printIndexedBuffer();
// printf("!! \n");
if (_calcV) if (_calcV)
_v.p(row1W+k, col1W, 1.f); _v.p(row1W+k, col1W, 1.f);
@ -789,14 +779,10 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
temp.assign(_u({col1, col1+k+1, i, i+1}, true)); temp.assign(_u({col1, col1+k+1, i, i+1}, true));
} }
auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0);
temp1.assign(q1 * c0); _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0));
auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast<const NDArray&>(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0);
temp2.assign(q1 * (-s0)); _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0;
auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true);
temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0);
auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true);
temp4 *= c0;
} }
else { else {
@ -844,8 +830,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true);
blockM = 0.f; blockM = 0.f;
auto diag = blockM.diagonal('c'); auto diag = blockM.diagonal('c');
diag->assign(singVals); diag.assign(singVals);
delete diag;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
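
diagonal('c') now hands back a view by value instead of an owning pointer, which is what removes the delete diag lines in the SVD code above. A minimal sketch, assuming the code base's <NDArray.h>.

#include <NDArray.h>
using namespace nd4j;

void diagonalSketch(NDArray& m, const NDArray& singVals) {
    auto diag = m.diagonal('c');   // view over the main diagonal, returned by value
    diag.assign(singVals);         // writes through to m; nothing to delete afterwards
}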


@ -285,17 +285,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
bool cNcont = N == 1 || C->strideAt(1) == 1; bool cNcont = N == 1 || C->strideAt(1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
toDelete.push_back(pA); toDelete.push_back(pA);
aMcont = true; aMcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('f'); pB = new NDArray(B->dup('f'));
toDelete.push_back(pB); toDelete.push_back(pB);
bKcont = true; bKcont = true;
} }
if(!cMcont) { if(!cMcont) {
pC = C->dup('f'); pC = new NDArray(C->dup('f'));
toDelete.push_back(pC); toDelete.push_back(pC);
cMcont = true; cMcont = true;
} }
@ -418,7 +418,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
bool aNcont = N == 1 || A->strideAt(1) == 1; bool aNcont = N == 1 || A->strideAt(1) == 1;
if(!aMcont && !aNcont) { if(!aMcont && !aNcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
aMcont = true; aMcont = true;
} }
@ -866,12 +866,12 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
bool cNcont = N == 1 || C->strideAt(-1) == 1; bool cNcont = N == 1 || C->strideAt(-1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('c'); pA = new NDArray(A->dup('c'));
toDelete.push_back(pA); toDelete.push_back(pA);
aKcont = true; aKcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('c'); pB = new NDArray(B->dup('c'));
toDelete.push_back(pB); toDelete.push_back(pB);
bNcont = true; bNcont = true;
} }


@ -82,7 +82,7 @@ namespace nd4j {
// now we actually apply quantization // now we actually apply quantization
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e += increment) {
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>(1.0f * x[e] / nd4j::math::nd4j_max<float>(amax, amin) * max_byte)); rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
} }
}; };
@ -180,7 +180,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e += increment) {
int el = x[e]; int el = x[e];
int ael = nd4j::math::nd4j_abs<int>(el) - 1; int ael = nd4j::math::nd4j_abs<int>(el) - 1;
z[ael] += el > 0 ? threshold : -threshold; z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
} }
}; };


@ -32,21 +32,19 @@ namespace nd4j {
REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type");
auto tmp = x->dup(); auto tmp = x->dup();
tmp->applyTransform(nd4j::transform::Neg, nullptr, nullptr); tmp.applyTransform(nd4j::transform::Neg, tmp);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
helpers::concat(block.launchContext(), {x, tmp}, *z, x->rankOf()-1); helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1);
// NDArrayFactory<T>::concat({x, tmp}, -1, z); // NDArrayFactory<T>::concat({x, tmp}, -1, z);
// TODO: make this configurable? // TODO: make this configurable?
double threshold = 0.0; double threshold = 0.0;
z->applyScalar(nd4j::scalar::RELU, threshold); z->applyScalar(nd4j::scalar::RELU, threshold, *z);
STORE_RESULT(z); STORE_RESULT(z);
delete tmp;
return Status::OK(); return Status::OK();
} }
@ -94,7 +92,7 @@ namespace nd4j {
auto pos = dec->at(0); auto pos = dec->at(0);
auto neg = dec->at(1); auto neg = dec->at(1);
pos->applyPairwiseTransform(nd4j::pairwise::Subtract, neg, epsilon, nullptr); pos->applyPairwiseTransform(nd4j::pairwise::Subtract, *neg, *epsilon);
delete tmpResult; delete tmpResult;
delete dec; delete dec;


@ -31,7 +31,7 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::Cube, output, nullptr); input->applyTransform(nd4j::transform::Cube, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -32,7 +32,7 @@ namespace nd4j {
const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
input->applyScalar(nd4j::scalar::ELU, alpha, output); input->applyScalar(nd4j::scalar::ELU, alpha, *output);
return Status::OK(); return Status::OK();
} }


@ -30,7 +30,7 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::HardSigmoid, output, nullptr); input->applyTransform(nd4j::transform::HardSigmoid, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -30,7 +30,7 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::HardTanh, output, nullptr); input->applyTransform(nd4j::transform::HardTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -30,7 +30,7 @@ namespace nd4j {
auto z = this->getZ(block); auto z = this->getZ(block);
// just for lulz // just for lulz
first->applyTransform(nd4j::transform::Identity, z, nullptr); first->applyTransform(nd4j::transform::Identity, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -33,7 +33,7 @@ namespace nd4j {
auto x = INPUT_VARIABLE(i); auto x = INPUT_VARIABLE(i);
auto z = OUTPUT_VARIABLE(i); auto z = OUTPUT_VARIABLE(i);
x->applyTransform(transform::Identity, z, nullptr); x->applyTransform(transform::Identity, *z);
} }
} }


@ -31,7 +31,7 @@ namespace nd4j {
float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output); input->applyScalar(nd4j::scalar::LeakyRELU, alpha, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -30,7 +30,7 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::RationalTanh, output, nullptr); input->applyTransform(nd4j::transform::RationalTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -30,7 +30,7 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::RectifiedTanh, output, nullptr); input->applyTransform(nd4j::transform::RectifiedTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -32,7 +32,7 @@ namespace nd4j {
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
first->applyScalar(nd4j::scalar::RELU, scalar, z); first->applyScalar(nd4j::scalar::RELU, scalar, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -33,7 +33,7 @@ CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), output); input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), *output);
return Status::OK(); return Status::OK();
} }


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SELU, z, nullptr); first->applyTransform(nd4j::transform::SELU, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -29,7 +29,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::Sigmoid, z, nullptr); first->applyTransform(nd4j::transform::Sigmoid, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SoftPlus, z, nullptr); first->applyTransform(nd4j::transform::SoftPlus, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SoftSign, z, nullptr); first->applyTransform(nd4j::transform::SoftSign, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::Tanh, z, nullptr); first->applyTransform(nd4j::transform::Tanh, *z);
STORE_RESULT(*z); STORE_RESULT(*z);
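
The activation ops above all switch to the reference-target forms of applyTransform and applyScalar, passing *output (or the array itself) instead of a pointer that could be nullptr. A condensed sketch of the three call shapes, assuming the code base's <NDArray.h>; the ops and the scalar value are illustrative.

#include <NDArray.h>
using namespace nd4j;

void activationSketch(NDArray* input, NDArray* output) {
    input->applyTransform(nd4j::transform::Neg, *output);    // out-of-place into the output
    input->applyTransform(nd4j::transform::Neg, *input);     // in-place: the array is its own target
    input->applyScalar(nd4j::scalar::RELU, 0.0, *output);    // scalar op, target passed by reference too
}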


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false);
return Status::OK(); return Status::OK();
} }
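
The bitwise and shift ops above now hand the second operand and the output to applyTrueBroadcast by reference. A sketch mirroring that call; it assumes the header set of the surrounding op files (NDArray, BroadcastIntOpsTuple and the op enums in the nd4j namespace), and the trailing bool is kept exactly as in the calls above.

#include <NDArray.h>

namespace nd4j {
    static void intBroadcastSketch(NDArray* x, NDArray* y, NDArray* z) {
        // second operand and target are dereferenced into references
        x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false);
    }
}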


@ -44,7 +44,7 @@ namespace nd4j {
ExtraArguments arguments({a}); ExtraArguments arguments({a});
y->applyPairwiseTransform(pairwise::Axpy, x, z, &arguments); y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments);
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }


@ -33,8 +33,12 @@ CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) {
const int rank = x->rankOf(); const int rank = x->rankOf();
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
const bool fullUV = (bool)INT_ARG(0); bool fullUV = (bool)INT_ARG(0);
const bool calcUV = (bool)INT_ARG(1); const bool calcUV = (bool)INT_ARG(1);
if(calcUV == false)
fullUV = false;
const int switchNum = INT_ARG(2); const int switchNum = INT_ARG(2);
// #ifndef __CUDABLAS__ // #ifndef __CUDABLAS__


@ -29,7 +29,7 @@ namespace nd4j {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
x->applyTransform(transform::Not, z, nullptr); x->applyTransform(transform::Not, *z);
return Status::OK(); return Status::OK();
} }


@ -70,17 +70,13 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!cond->e<bool>(e)) { if (!cond->e<bool>(e)) {
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
} else { } else {
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} }


@ -59,17 +59,13 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!condition->e<bool>(e)) { if (!condition->e<bool>(e)) {
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
} else { } else {
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} else { } else {
// in this case we return 2D matrix, which basically contains coordinates fo true // in this case we return 2D matrix, which basically contains coordinates fo true


@ -89,16 +89,12 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!condition->e<bool>(e)) if (!condition->e<bool>(e))
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
else else
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} else { } else {
// in this case we return 2D matrix, which basically contains coordinates fo true // in this case we return 2D matrix, which basically contains coordinates fo true


@ -82,14 +82,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(epsNext); gradX->assign(epsNext);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(epsNext); gradY->assign(epsNext);
} }


@ -80,7 +80,6 @@ namespace nd4j {
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(epsNext); gradY->assign(epsNext);
} }
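
The removed delete sum lines above follow from reduceAlongDimension returning the reduced array by value. A small sketch of the resulting pattern, assuming the code base's <NDArray.h>.

#include <NDArray.h>
#include <vector>
using namespace nd4j;

void reduceSumSketch(NDArray* epsNext, NDArray* grad, const std::vector<int>& axes) {
    if (!axes.empty()) {
        auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axes);   // a value, not an owning pointer
        grad->assign(sum);                                                   // no delete afterwards
    } else {
        grad->assign(epsNext);
    }
}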


@ -36,7 +36,7 @@ BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
// auto tZ = BroadcastHelper<T>::template broadcastApply<simdOps::Atan2<T>>(y, x, z); // auto tZ = BroadcastHelper<T>::template broadcastApply<simdOps::Atan2<T>>(y, x, z);
x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), y, z, true); x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true);
// if (tZ == nullptr) // if (tZ == nullptr)
// return ND4J_STATUS_KERNEL_FAILURE; // return ND4J_STATUS_KERNEL_FAILURE;


@ -81,7 +81,7 @@ namespace nd4j {
// Y gradient // Y gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); gradY->assign((*epsNext) * (*x) / ((*y) * (*y)));
gradY->applyTransform(transform::Neg, nullptr, nullptr); gradY->applyTransform(transform::Neg, *gradY);
} else if (y->isScalar()) { } else if (y->isScalar()) {
// scalar case // scalar case
@ -91,17 +91,17 @@ namespace nd4j {
//tmpX.printBuffer("SumX"); //tmpX.printBuffer("SumX");
//tmp.printBuffer("Sum Eps"); //tmp.printBuffer("Sum Eps");
gradY->assign(tmp * tmpX / ((*y) * (*y))); gradY->assign(tmp * tmpX / ((*y) * (*y)));
gradY->applyTransform(transform::Neg, nullptr, nullptr); gradY->applyTransform(transform::Neg, *gradY);
//epsNext->applyLambda(lambdaS, gradX); //epsNext->applyLambda(lambdaS, *gradX);
epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); epsNext->applyScalarArr(scalar::Divide, *y, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preX = *epsNext / *y; auto preX = *epsNext / *y;
NDArray negX(*x); NDArray negX(*x);
x->applyTransform(transform::Neg, &negX); x->applyTransform(transform::Neg, negX);
auto preY = *epsNext * negX / ((*y) * (*y)); auto preY = *epsNext * negX / ((*y) * (*y));
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -110,14 +110,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -69,7 +69,7 @@ namespace nd4j {
std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {})); std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {}));
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, tmpResult->at(0), gradY, nullptr); epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);
else // epsNext is greater than gradY else // epsNext is greater than gradY
{ {
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2); std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -78,7 +78,7 @@ namespace nd4j {
dims[d * 2 + 1] = 1; dims[d * 2 + 1] = 1;
} }
auto tempIn((*tmpResult->at(0))(dims)); auto tempIn((*tmpResult->at(0))(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, &tempIn, gradY, nullptr); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
} }
return Status::OK(); return Status::OK();
} }
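
applyPairwiseTransform likewise takes its second operand and the target by reference now, and the ExtraArguments pointer is dropped where it was always nullptr. A one-call sketch, assuming the code base's <NDArray.h>.

#include <NDArray.h>
using namespace nd4j;

void pairwiseSketch(NDArray* epsNext, NDArray* other, NDArray* grad) {
    epsNext->applyPairwiseTransform(nd4j::pairwise::Multiply, *other, *grad);
}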


@ -79,24 +79,24 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
const Nd4jLong yLen = y->lengthOf(); const Nd4jLong yLen = y->lengthOf();
if(x->isScalar() && y->isScalar()) { // both are scalars if(x->isScalar() && y->isScalar()) { // both are scalars
y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
//dLdx->assign((*y) * (*dLdz)); //dLdx->assign((*y) * (*dLdz));
//dLdy->assign((*x) * (*dLdz)); //dLdy->assign((*x) * (*dLdz));
} }
else if(x->isScalar()) { // x is scalar and y is not else if(x->isScalar()) { // x is scalar and y is not
dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum));
dLdz->applyScalarArr(scalar::Multiply, x, dLdy, nullptr); dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy);
//dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true);
} }
else if(y->isScalar()) { // y is scalar and x is not else if(y->isScalar()) { // y is scalar and x is not
dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum));
dLdz->applyScalarArr(scalar::Multiply, y, dLdx); dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx);
} }
else if(x->isSameShape(y)) { else if(x->isSameShape(y)) {
x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
} }
else if (x->isSameShape(dLdz)) { else if (x->isSameShape(dLdz)) {
@ -104,8 +104,8 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
y->tile(yTiled); y->tile(yTiled);
std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo());
dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) );
yTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
} }
else if (y->isSameShape(dLdz)) { else if (y->isSameShape(dLdz)) {
@ -113,8 +113,8 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
x->tile(xTiled); x->tile(xTiled);
std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo());
dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) );
xTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
} }
else { else {
@ -125,8 +125,8 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo());
std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo());
dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) );
dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) );
} }
return Status::OK(); return Status::OK();
@ -182,7 +182,7 @@ DECLARE_SHAPE_FN(multiply_bp) {
T tmpX = x->template reduceNumber<simdOps::Sum<T>>(); T tmpX = x->template reduceNumber<simdOps::Sum<T>>();
gradY->assign(tmpX); gradY->assign(tmpX);
epsNext->applyLambda(lambdaS, gradX); epsNext->applyLambda(lambdaS, *gradX);
} else { } else {
// broadcast case // broadcast case


@ -71,7 +71,7 @@ namespace nd4j {
// X gradient // X gradient
//epsNext->applyPairwiseLambda(y, lambdaX, gradX); //epsNext->applyPairwiseLambda(y, lambdaX, gradX);
epsNext->applyPairwiseTransform(pairwise::Divide, y, gradX, nullptr); epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX);
// Y gradient // Y gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
@ -86,14 +86,14 @@ namespace nd4j {
gradY->assign(tmp * -tmpX / ((*y) * (*y))); gradY->assign(tmp * -tmpX / ((*y) * (*y)));
//epsNext->applyLambda(lambdaS, gradX); //epsNext->applyLambda(lambdaS, gradX);
epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); epsNext->applyScalarArr(scalar::Divide, *y, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preX = *epsNext / *y; auto preX = *epsNext / *y;
NDArray negX(*x); NDArray negX(*x);
x->applyTransform(transform::Neg, &negX); x->applyTransform(transform::Neg, negX);
auto preY = *epsNext * negX / ((*y) * (*y)); auto preY = *epsNext * negX / ((*y) * (*y));
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -102,14 +102,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!");
x->applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true);
return Status::OK(); return Status::OK();
} }
@ -67,7 +67,7 @@ namespace nd4j {
// X gradient // X gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX);
gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); gradX->assign((*epsNext) * (*y) / ((*x) * (*x)));
gradX->applyTransform(transform::Neg, nullptr, nullptr); gradX->applyTransform(transform::Neg, *gradX);
// Y gradient // Y gradient
//epsNext->applyPairwiseLambda(x, lambdaY, gradY); //epsNext->applyPairwiseLambda(x, lambdaY, gradY);
gradY->assign((*epsNext) / (*x)); gradY->assign((*epsNext) / (*x));
@ -78,14 +78,14 @@ namespace nd4j {
gradY->assign(tmp / tmpX); gradY->assign(tmp / tmpX);
gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); gradX->assign((*epsNext) * (*y) / ((*x) * (*x)));
gradX->applyTransform(transform::Neg, nullptr, nullptr); gradX->applyTransform(transform::Neg, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preY = (*epsNext) / (*x); auto preY = (*epsNext) / (*x);
auto preX = *epsNext * (*y) / ((*x) * (*x)); auto preX = *epsNext * (*y) / ((*x) * (*x));
preX.applyTransform(transform::Neg, nullptr, nullptr); preX.applyTransform(transform::Neg, preX);
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
@ -93,14 +93,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -61,13 +61,13 @@ namespace nd4j {
if (x->isSameShape(y)) { if (x->isSameShape(y)) {
// PWT case case // PWT case case
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
gradY->assign(epsNext); gradY->assign(epsNext);
} else if (y->isScalar()) { } else if (y->isScalar()) {
// scalar case // scalar case
auto tmp = epsNext->reduceNumber(reduce::Sum); auto tmp = epsNext->reduceNumber(reduce::Sum);
gradY->assign(tmp); gradY->assign(tmp);
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
} else { } else {
// broadcastable // broadcastable
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -75,16 +75,14 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX);
sum->applyTransform(transform::Neg, gradX); sum.applyTransform(transform::Neg, *gradX);
delete sum;
} else { } else {
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
} }
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else { } else {
gradY->assign(epsNext); gradY->assign(epsNext);
} }


@ -98,37 +98,31 @@ namespace nd4j {
auto targetShape = epsNext->getShapeAsVector(); auto targetShape = epsNext->getShapeAsVector();
preX->tileToShape(targetShape); preX.tileToShape(targetShape, preX);
preY->tileToShape(targetShape); preY.tileToShape(targetShape, preY);
//epsNext->applyTriplewiseLambda(x, y, lambdaX, preX); //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX);
//epsNext->applyTriplewiseLambda(x, y, lambdaY, preY); //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY);
auto resX = (*epsNext) * ts * ((*x) - (*y)); auto resX = (*epsNext) * ts * ((*x) - (*y));
preX->assign(resX); preX.assign(resX);
auto resY = (*epsNext) * ts * ((*y) - (*x)); auto resY = (*epsNext) * ts * ((*y) - (*x));
preY->assign(resY); preY.assign(resY);
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX->reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY->reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
delete preX;
delete preY;
} }
return Status::OK(); return Status::OK();
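
tileToShape now writes into a reference target, and the intermediate preX/preY arrays above are plain values, which removes the trailing deletes. A sketch, assuming the code base's <NDArray.h>.

#include <NDArray.h>
#include <vector>
using namespace nd4j;

void tileSketch(NDArray& pre, const std::vector<Nd4jLong>& targetShape) {
    pre.tileToShape(targetShape, pre);   // tile in place up to the broadcast target shape
}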


@ -62,7 +62,7 @@ namespace nd4j {
if (x->isSameShape(y)) { if (x->isSameShape(y)) {
// PWT case case // PWT case case
epsNext->applyTransform(transform::Neg, gradY, nullptr); epsNext->applyTransform(transform::Neg, *gradY);
gradX->assign(epsNext); gradX->assign(epsNext);
} else if (y->isScalar()) { } else if (y->isScalar()) {
// scalar case // scalar case
@ -77,16 +77,14 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum;
} else } else
gradX->assign(epsNext); gradX->assign(epsNext);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY);
sum->applyTransform(transform::Neg, gradY); sum.applyTransform(transform::Neg, *gradY);
delete sum;
} else { } else {
epsNext->applyTransform(transform::Neg, gradY); epsNext->applyTransform(transform::Neg, *gradY);
} }
} }


@ -41,10 +41,10 @@ namespace nd4j {
// but we'll ensure only one node is active, and other is disabled // but we'll ensure only one node is active, and other is disabled
if (condition->e<int>(0) == 0) { if (condition->e<int>(0) == 0) {
block.setBranch(0); block.setBranch(0);
this->storeResult(block, 0, input->dup()); this->storeResult(block, 0, new NDArray(input->dup()));
} else { } else {
block.setBranch(1); block.setBranch(1);
this->storeResult(block, 1, *input->dup()); this->storeResult(block, 1, new NDArray(input->dup()));
} }
return Status::OK(); return Status::OK();


@ -42,34 +42,34 @@ namespace nd4j {
std::unique_ptr<NDArray> ptr; std::unique_ptr<NDArray> ptr;
if (!Environment::getInstance()->isExperimentalBuild()) { if (!Environment::getInstance()->isExperimentalBuild()) {
if (y->dataType() != x->dataType()) { if (y->dataType() != x->dataType()) {
y = y->cast(x->dataType()); y = new NDArray(y->cast(x->dataType()));
std::unique_ptr<NDArray> ptr2(y); std::unique_ptr<NDArray> ptr2(y);
ptr.swap(ptr2); ptr.swap(ptr2);
} }
} }
if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) {
x->applyPairwiseTransform(op.p, y, z, nullptr); x->applyPairwiseTransform(op.p, *y, *z);
} else if (!x->isScalar() && y->isScalar()) { } else if (!x->isScalar() && y->isScalar()) {
x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z); x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
} else if (x->isScalar() && !y->isScalar()) { } else if (x->isScalar() && !y->isScalar()) {
if (z->isSameShape(y)) { if (z->isSameShape(y)) {
if (op.s == scalar::Add || op.s == scalar::Multiply ) { if (op.s == scalar::Add || op.s == scalar::Multiply ) {
y->applyScalarArr(op.s, x, z, nullptr); y->applyScalarArr(op.s, *x, *z);
} else if (op.s == scalar::SquaredSubtract) { } else if (op.s == scalar::SquaredSubtract) {
y->applyScalarArr(scalar::SquaredReverseSubtract, x, z, nullptr); y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z);
} else if (op.s == scalar::Subtract) { } else if (op.s == scalar::Subtract) {
y->applyScalarArr(scalar::ReverseSubtract, x, z, nullptr); y->applyScalarArr(scalar::ReverseSubtract, *x, *z);
} else if (op.s == scalar::Divide) { } else if (op.s == scalar::Divide) {
y->applyScalarArr(scalar::ReverseDivide, x, z, nullptr); y->applyScalarArr(scalar::ReverseDivide, *x, *z);
} else if (op.s == scalar::Pow) { } else if (op.s == scalar::Pow) {
y->applyScalarArr(scalar::ReversePow, x, z, nullptr); y->applyScalarArr(scalar::ReversePow, *x, *z);
} else if (op.s == scalar::ReverseSubtract) { } else if (op.s == scalar::ReverseSubtract) {
y->applyScalarArr(scalar::Subtract, x, z, nullptr); y->applyScalarArr(scalar::Subtract, *x, *z);
} else if (op.s == scalar::ReverseDivide) { } else if (op.s == scalar::ReverseDivide) {
y->applyScalarArr(scalar::Divide, x, z, nullptr); y->applyScalarArr(scalar::Divide, *x, *z);
} else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) { } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) {
y->applyScalarArr(op.s, x, z, nullptr); y->applyScalarArr(op.s, *x, *z);
} else if (op.s == scalar::CopyPws) { } else if (op.s == scalar::CopyPws) {
z->assign(y); z->assign(y);
} else { } else {
@ -84,9 +84,9 @@ namespace nd4j {
return tZ; return tZ;
} }
} else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr); x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
} else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
x->applyTrueBroadcast(op, y, z, true, extraArgs); x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
return z; return z;
} else { } else {
auto sx = ShapeUtils::shapeAsString(x); auto sx = ShapeUtils::shapeAsString(x);
@ -107,16 +107,16 @@ namespace nd4j {
} }
if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) { if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) {
x->applyPairwiseTransform(op.p, y, z, nullptr); x->applyPairwiseTransform(op.p, *y, *z);
} else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
x->applyTrueBroadcast(op, y, z, true, extraArgs); x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
return z; return z;
} else if (!x->isScalar() && y->isScalar()) { } else if (!x->isScalar() && y->isScalar()) {
x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z); x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
} else if (x->isScalar() && !y->isScalar()) { } else if (x->isScalar() && !y->isScalar()) {
if (z->isSameShape(y)) { if (z->isSameShape(y)) {
//z->assign(x); //z->assign(x);
x->applyPairwiseTransform(op.p, y, z, extraArgs); x->applyPairwiseTransform(op.p, *y, *z, extraArgs);
return z; return z;
} else { } else {
auto v = y->getShapeAsVector(); auto v = y->getShapeAsVector();
@ -125,9 +125,9 @@ namespace nd4j {
return tZ; return tZ;
} }
} else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar() } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr); x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
} else if (ShapeUtils::areShapesBroadcastable(*x, *y)) { } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
x->applyTrueBroadcast(op, y, z, true, extraArgs); x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
return z; return z;
} else { } else {
auto sx = ShapeUtils::shapeAsString(x); auto sx = ShapeUtils::shapeAsString(x);
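Every branch of this dispatcher now passes operands and the output by reference and drops the trailing nullptr extra-arguments parameter. A condensed, illustrative sketch of the same dispatch; the tuple type name BroadcastOpsTuple and the defaulted extra-arguments parameter are assumptions:

    #include <NDArray.h>   // assumed libnd4j headers
    #include <stdexcept>

    using namespace nd4j;

    void applyBroadcastableOp(const BroadcastOpsTuple& op, NDArray& x, NDArray& y, NDArray& z) {
        if (!x.isScalar() && !y.isScalar() && x.isSameShape(&y))
            x.applyPairwiseTransform(op.p, y, z);          // both operands by reference
        else if (!x.isScalar() && y.isScalar())
            x.applyScalarArr(op.s, y, z);                  // scalar operand is a const NDArray&
        else if (ShapeUtils::areShapesBroadcastable(x, y))
            x.applyTrueBroadcast(op, y, z, true);          // extra args omitted, assuming a default
        else
            throw std::runtime_error("shapes are neither equal nor broadcastable");
    }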

View File

@@ -51,12 +51,12 @@ namespace nd4j {
     std::vector<int> axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0});
     auto tads = array->allTensorsAlongDimension(axis);
-    for (int e = 0; e < tads->size(); e++) {
+    for (int e = 0; e < tads.size(); e++) {
         auto idx = indices->e<int>(e);
-        if (idx >= tads->size())
+        if (idx >= tads.size())
             return ND4J_STATUS_BAD_ARGUMENTS;
-        auto arr = tads->at(e)->dup(array->ordering());
+        auto arr = new NDArray(tads.at(e)->dup(array->ordering()));
         auto res = list->write(idx, arr);
         if (res != ND4J_STATUS_OK)
             return res;
@@ -65,7 +65,6 @@ namespace nd4j {
     if (!hasList)
         //OVERWRITE_RESULT(list);
         setupResultList(list, block);
-    delete tads;
     return Status::OK();
 }
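allTensorsAlongDimension now returns its result set by value, which is why the loop switches from tads-> to tads. and the trailing delete disappears. A sketch of the new lifetime with an illustrative helper that collects heap copies of each TAD:

    #include <NDArray.h>   // assumed libnd4j headers
    #include <vector>

    using namespace nd4j;

    void duplicateTads(NDArray& array, std::vector<NDArray*>& out) {
        std::vector<int> axis = ShapeUtils::evalDimsToExclude(array.rankOf(), {0});
        auto tads = array.allTensorsAlongDimension(axis);   // value-typed result set
        for (int e = 0; e < tads.size(); e++)
            out.push_back(new NDArray(tads.at(e)->dup(array.ordering())));
    }   // no `delete tads;` here, the result set cleans itself up on scope exit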

View File

@@ -66,7 +66,7 @@ namespace nd4j {
     auto subarray = (*array)(indices);
-    auto status = list->write(e, subarray.dup(array->ordering()));
+    auto status = list->write(e, new NDArray(subarray.dup(array->ordering())));
     if (status != ND4J_STATUS_OK)
         return status;

View File

@@ -39,7 +39,7 @@ namespace nd4j {
     //nd4j_printf("Writing [%i]:\n", idx->e<int>(0));
     //input->printShapeInfo("input shape");
     //input->printIndexedBuffer("input buffer");
-    Nd4jStatus result = list->write(idx->e<int>(0), input->dup());
+    Nd4jStatus result = list->write(idx->e<int>(0), new NDArray(input->dup()));
     auto res = NDArrayFactory::create_(list->counter(), block.launchContext());
     //res->printShapeInfo("Write_list 2 output shape");
@@ -52,7 +52,7 @@ namespace nd4j {
     auto input = INPUT_VARIABLE(1);
     auto idx = INT_ARG(0);
-    Nd4jStatus result = list->write(idx, input->dup());
+    Nd4jStatus result = list->write(idx, new NDArray(input->dup()));
     auto res = NDArrayFactory::create_(list->counter(), block.launchContext());
     //res->printShapeInfo("Write_list 1 output shape");

View File

@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
NDArray E = *predictions - *labels; NDArray E = *predictions - *labels;
// dE_i/dp_i = sign(p_i - y_i) // dE_i/dp_i = sign(p_i - y_i)
E.applyTransform(nd4j::transform::Sign, dLdp); // dE/dp E.applyTransform(nd4j::transform::Sign, *dLdp); // dE/dp
// dE_i/dy_i = -sign(p_i - y_i) // dE_i/dy_i = -sign(p_i - y_i)
E.applyTransform(nd4j::transform::Abs); E.applyTransform(nd4j::transform::Abs, E);
switch (reductionMode) { switch (reductionMode) {
@ -184,7 +184,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
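The dLdw branch above is the weight-gradient pattern repeated across all the loss backward ops in this PR: when the weights were broadcast, the per-element loss E is summed back over the broadcast axes directly into dLdw, now passed by reference. A sketch of just that step (helper name is illustrative, flag order copied from the hunk above):

    #include <NDArray.h>   // assumed libnd4j headers
    #include <vector>

    using namespace nd4j;

    void reduceWeightGradient(NDArray& E, NDArray* weights, NDArray* weightsBroad, NDArray& dLdw) {
        if (weights == weightsBroad) {
            dLdw.assign(E);      // nothing was broadcast
        } else {
            std::vector<int> axes = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
            E.reduceAlongDimension(reduce::Sum, dLdw, axes, true, false, false);   // same flags as above
        }
    }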
@ -210,7 +210,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -238,7 +238,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str()); REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
} }
NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true);
// perform weights broadcasting/tile to E if it is necessary // perform weights broadcasting/tile to E if it is necessary
auto weightsBroad = weights; auto weightsBroad = weights;
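The only change in this hunk is the rename reduceAlongDims to reduceAlongDimension, which returns the reduced array by value. A compact illustration of the same expression, assuming the arithmetic operator overloads used above:

    #include <NDArray.h>   // assumed libnd4j headers

    using namespace nd4j;

    // E = 1 - sum_dim(predictions * labels), keeping the reduced dimension
    NDArray cosineDistanceError(NDArray& predictions, NDArray& labels, int dim) {
        return 1. - (predictions * labels).reduceAlongDimension(reduce::Sum, {dim}, true);
    }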
@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
// input dimension can't be larger than labels/predictions/weights rank // input dimension can't be larger than labels/predictions/weights rank
REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf()); REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf());
NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true); NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true);
// perform weights broadcasting/tile to E if it is necessary // perform weights broadcasting/tile to E if it is necessary
auto weightsBroad = weights; auto weightsBroad = weights;
@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
else { else {
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -249,7 +249,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -284,7 +284,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeights; *dLdw /= numOfNonZeroWeights;
} }
else else

View File

@ -52,7 +52,7 @@ namespace nd4j {
// We first need to convert binary labels to -1/1 labels (as floats) // We first need to convert binary labels to -1/1 labels (as floats)
NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits); NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits);
E.applyScalar(scalar::RELU, 0.0f, &E); E.applyScalar(scalar::RELU, 0.0f, E);
// multiply E on weights // multiply E on weights
E *= *weightsBroad; E *= *weightsBroad;
@ -172,11 +172,11 @@ namespace nd4j {
NDArray z = (*labels * 2.f - 1.f); NDArray z = (*labels * 2.f - 1.f);
NDArray E = 1.f - z * (*logits); NDArray E = 1.f - z * (*logits);
E.applyScalar(scalar::RELU, 0.0f, &E); E.applyScalar(scalar::RELU, 0.0f, E);
// turn E into gradient mask // turn E into gradient mask
NDArray gradientMask(E.getShapeInfo(), block.getWorkspace()); NDArray gradientMask(E.getShapeInfo(), block.getWorkspace());
E.applyTransform(nd4j::transform::Sign, &gradientMask); E.applyTransform(nd4j::transform::Sign, gradientMask);
dLdp->assign(-z * gradientMask); dLdp->assign(-z * gradientMask);
dLdl->assign(-2.f * (*logits) * gradientMask); dLdl->assign(-2.f * (*logits) * gradientMask);
@ -192,7 +192,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -220,7 +220,7 @@ namespace nd4j {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -249,7 +249,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -53,9 +53,9 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
auto error = *predictions - *labels; auto error = *predictions - *labels;
error.applyTransform(transform::Abs); error.applyTransform(transform::Abs, error);
NDArray quadratic(error.getShapeInfo(), block.getWorkspace()); NDArray quadratic(error.getShapeInfo(), block.getWorkspace());
error.applyScalar(scalar::MinPairwise, delta, &quadratic); error.applyScalar(scalar::MinPairwise, delta, quadratic);
NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta; NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta;
@ -173,24 +173,24 @@ DECLARE_SHAPE_FN(huber_loss) {
NDArray diff = *predictions - *labels; NDArray diff = *predictions - *labels;
NDArray absDiff(diff); NDArray absDiff(diff);
absDiff.applyTransform(transform::Abs); absDiff.applyTransform(transform::Abs, absDiff);
NDArray quadratic(absDiff); NDArray quadratic(absDiff);
absDiff.applyScalar(scalar::MinPairwise, delta, &quadratic); absDiff.applyScalar(scalar::MinPairwise, delta, quadratic);
NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta; NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta;
NDArray lteMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); NDArray lteMask(diff.getShapeInfo(), BOOL, true, block.launchContext());
absDiff.applyScalar(scalar::LessThanOrEqual, delta, &lteMask); absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask);
NDArray gtMask(diff.getShapeInfo(), BOOL, true, block.launchContext()); NDArray gtMask(diff.getShapeInfo(), BOOL, true, block.launchContext());
absDiff.applyScalar(scalar::GreaterThan, delta, &gtMask); absDiff.applyScalar(scalar::GreaterThan, delta, gtMask);
NDArray signDiff(diff); NDArray signDiff(diff);
diff.applyTransform(transform::Sign, &signDiff); diff.applyTransform(transform::Sign, signDiff);
auto gtMaskFloat = *gtMask.cast(diff.dataType()); auto gtMaskFloat = gtMask.cast(diff.dataType());
auto lteMaskFloat = *lteMask.cast(diff.dataType()); auto lteMaskFloat = lteMask.cast(diff.dataType());
dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff); dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff);
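Two API changes meet here: applyScalar and applyTransform take their output by reference, and cast() returns a new NDArray by value, so the extra dereference disappears. A sketch of the mask-based Huber gradient under those assumptions; the helper name and the explicit LaunchContext parameter are illustrative:

    #include <NDArray.h>   // assumed libnd4j headers

    using namespace nd4j;

    void huberGradWrtPredictions(NDArray& diff, double delta, NDArray& dLdp, LaunchContext* ctx) {
        NDArray absDiff(diff);
        absDiff.applyTransform(transform::Abs, absDiff);                  // in place, explicit target
        NDArray lteMask(diff.getShapeInfo(), DataType::BOOL, true, ctx);
        absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask);     // |d| <= delta
        NDArray gtMask(diff.getShapeInfo(), DataType::BOOL, true, ctx);
        absDiff.applyScalar(scalar::GreaterThan, delta, gtMask);          // |d| >  delta
        NDArray signDiff(diff);
        diff.applyTransform(transform::Sign, signDiff);
        auto gtMaskFloat  = gtMask.cast(diff.dataType());                 // by value, no dereference
        auto lteMaskFloat = lteMask.cast(diff.dataType());
        dLdp.assign(lteMaskFloat * diff + gtMaskFloat * delta * signDiff);
    }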
@ -207,7 +207,7 @@ DECLARE_SHAPE_FN(huber_loss) {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -235,7 +235,7 @@ DECLARE_SHAPE_FN(huber_loss) {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -264,7 +264,7 @@ DECLARE_SHAPE_FN(huber_loss) {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -181,7 +181,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
// dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps) // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps)
dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp
// dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1) // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1)
((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, dLdl); // dE/dy ((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy
NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log); NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log);
@ -196,7 +196,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -226,7 +226,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -254,7 +254,7 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights); dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights);
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -55,9 +55,9 @@ namespace ops {
NDArray E(labels->getShapeInfo(), block.getWorkspace()); NDArray E(labels->getShapeInfo(), block.getWorkspace());
if (computeFullLoss) if (computeFullLoss)
labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E);
else else
labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E);
// multiply E on weights // multiply E on weights
@ -176,13 +176,13 @@ namespace ops {
NDArray E(labels->getShapeInfo(), block.getWorkspace()); NDArray E(labels->getShapeInfo(), block.getWorkspace());
if (computeFullLoss) { if (computeFullLoss) {
labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr); labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E);
NDArray rDiv(labels->getShapeInfo(), block.getWorkspace()); NDArray rDiv(labels->getShapeInfo(), block.getWorkspace());
labels->applyScalar(scalar::ReverseDivide, 0.5f, &rDiv); labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv);
dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions)); dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions));
} else { } else {
labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr); labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E);
dLdl->assign(-(*log_predictions)); dLdl->assign(-(*log_predictions));
} }
@ -200,7 +200,7 @@ namespace ops {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -228,7 +228,7 @@ namespace ops {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -257,7 +257,7 @@ namespace ops {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -112,10 +112,10 @@ namespace nd4j {
auto n = double(labels->sizeAt(1)); auto n = double(labels->sizeAt(1));
auto diffs = *predictions - *labels; auto diffs = *predictions - *labels;
auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true);
auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true);
squareOfSum.applyScalar(scalar::Pow, 2); squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum);
auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1)));
@ -240,15 +240,15 @@ namespace nd4j {
auto diffs = *predictions - *labels; auto diffs = *predictions - *labels;
std::vector<int> reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0}); std::vector<int> reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0});
auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true); auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true);
auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true); auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true);
squareOfSum.applyScalar(scalar::Pow, 2); squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum);
auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1))); auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1)));
auto sumPred = predictions->reduceAlongDims(reduce::Sum, reductionIdx, true); auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true);
auto sumLabel = labels->reduceAlongDims(reduce::Sum, reductionIdx, true); auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true);
dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1)))); dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1))));
@ -273,7 +273,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -299,7 +299,7 @@ namespace nd4j {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -327,7 +327,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -51,7 +51,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo())); weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
NDArray E(labels->getShapeInfo(), false, block.launchContext()); NDArray E(labels->getShapeInfo(), false, block.launchContext());
predictions->applyPairwiseTransform(pairwise::SquaredSubtract, labels, &E, nullptr); predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E);
// multiply E on weights // multiply E on weights
E *= (*weightsBroad); E *= (*weightsBroad);
@ -191,7 +191,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -217,7 +217,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -245,7 +245,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights)); dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
auto newLabels = labels; auto newLabels = labels;
if(labelsSmoothing != 0.) { if(labelsSmoothing != 0.) {
newLabels = new NDArray(*labels); newLabels = new NDArray(*labels);
newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, newLabels, nullptr); newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels);
} }
NDArray E(labels, false, block.launchContext()); NDArray E(labels, false, block.launchContext());
@ -186,7 +186,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
auto newLabels = labels; auto newLabels = labels;
if(labelsSmoothing.e<float>(0) != 0.f) { if(labelsSmoothing.e<float>(0) != 0.f) {
newLabels = new NDArray(*labels); newLabels = new NDArray(*labels);
newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e<float>(0), newLabels, nullptr); newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e<float>(0), *newLabels);
} }
NDArray E(labels, false, block.launchContext()); NDArray E(labels, false, block.launchContext());
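Label smoothing is now applied in place by naming the target explicitly and dropping the trailing nullptr. A sketch of that step using the same ownership convention as the op, where the caller deletes the result only when it differs from the input (helper name is illustrative):

    #include <NDArray.h>   // assumed libnd4j headers

    using namespace nd4j;

    NDArray* smoothLabels(NDArray* labels, float labelsSmoothing) {
        if (labelsSmoothing == 0.f)
            return labels;                                   // reuse the original array
        auto newLabels = new NDArray(*labels);
        newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels);
        return newLabels;                                    // caller owns and deletes this copy
    }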
@ -211,7 +211,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum)); dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
@ -239,7 +239,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
*dLdw = 0.; *dLdw = 0.;
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum));
@ -267,7 +267,7 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar); dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar);
else if(weights != weightsBroad) { else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar; *dLdw /= numOfNonZeroWeightsScalar;
} }
else else

View File

@ -54,11 +54,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
// If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
// num_classes = labels->sizeAt(1) // num_classes = labels->sizeAt(1)
auto cLabels = labels->cast(weights->dataType()); NDArray* cLabels = new NDArray(labels->cast(weights->dataType()));
auto newLabels = cLabels; NDArray* newLabels = cLabels;
if(labelsSmoothing != 0.) { if(labelsSmoothing != 0.) {
newLabels = new NDArray(cLabels); newLabels = new NDArray(cLabels);
*newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
} }
// main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension
@ -70,9 +70,9 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
std::vector<int> dimensions = {-1}; std::vector<int> dimensions = {-1};
NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true);
NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log);
NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions);
// perform weights broadcasting/tile to E if it is necessary // perform weights broadcasting/tile to E if it is necessary
auto weightsBroad = weights; auto weightsBroad = weights;
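The loss term itself is the shifted-logits (log-sum-exp) form, now spelled with value-returning reduceAlongDimension calls. A sketch of the same expression as a standalone helper (name is illustrative):

    #include <NDArray.h>   // assumed libnd4j headers
    #include <vector>

    using namespace nd4j;

    // E = sum_dim( labels * (logSumExp - shiftedLogits) ), reduced over the last dimension
    NDArray softmaxCrossEntropy(NDArray& logits, NDArray& labels) {
        std::vector<int> dims = {-1};
        NDArray shiftedLogits = logits - logits.reduceAlongDimension(reduce::Max, dims, true);
        NDArray logSumExp = shiftedLogits.transform(transform::Exp)
                                         .reduceAlongDimension(reduce::Sum, dims, true)
                                         .transform(transform::Log);
        return (labels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dims);
    }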
@ -217,25 +217,25 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
// If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
// num_classes = labels->sizeAt(1) // num_classes = labels->sizeAt(1)
auto cLabels = labels->cast(weights->dataType()); NDArray* cLabels = new NDArray(labels->cast(weights->dataType()));
auto newLabels = cLabels; NDArray* newLabels = cLabels;
if(labelsSmoothing != 0.) { if(labelsSmoothing != 0.) {
newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext()); newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext());
newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
} }
NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp); NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp);
softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true); softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true);
// dEdp = softmax * sum_i(lables_i) - labels // dEdp = softmax * sum_i(lables_i) - labels
dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels); dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels);
// dEdl = -log(softmax) // dEdl = -log(softmax)
dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing)); dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing));
NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true);
NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log);
NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions);
// perform weights broadcasting/tile to E if it is necessary // perform weights broadcasting/tile to E if it is necessary
auto weightsBroad = weights; auto weightsBroad = weights;
@ -253,12 +253,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
*dLdl *= *weights; *dLdl *= *weights;
} }
else { else {
dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdp);
dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad); dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdl);
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign(E); dLdw->assign(E);
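When the weights had to be broadcast, the gradients are rescaled with applyBroadcast, which now takes the broadcast operand and the output array by reference, here updating the outputs in place. A minimal sketch:

    #include <NDArray.h>   // assumed libnd4j headers
    #include <vector>

    using namespace nd4j;

    void scaleGradsByWeights(NDArray& dLdp, NDArray& dLdl, const NDArray& weightsBroad,
                             const std::vector<int>& dimensions) {
        dLdp.applyBroadcast(broadcast::Multiply, dimensions, weightsBroad, dLdp);   // in place
        dLdl.applyBroadcast(broadcast::Multiply, dimensions, weightsBroad, dLdl);
    }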
@ -289,12 +289,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
else { else {
NDArray temp = *weightsBroad / sum; NDArray temp = *weightsBroad / sum;
dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp);
dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl);
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
} }
else else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)); dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@ -326,12 +326,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
} }
else { else {
NDArray temp = *weightsBroad / numOfNonZeroWeights; NDArray temp = *weightsBroad / numOfNonZeroWeights;
dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp);
dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp); dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl);
if(weights != weightsBroad) { if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo()); std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false); E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeights; *dLdw /= numOfNonZeroWeights;
} }
else else

View File

@ -41,11 +41,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) {
std::vector<int> dimension = {classesDim}; std::vector<int> dimension = {classesDim};
auto maxAlongDim = logits->reduceAlongDims(reduce::Max, {classesDim}, true); auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true);
auto logExp = (*logits - maxAlongDim).transform(transform::Exp); auto logExp = (*logits - maxAlongDim).transform(transform::Exp);
auto logSoftMax = ( logExp / logExp.reduceAlongDims(reduce::Sum, {classesDim}, true) ).transform(transform::Log); auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log);
(-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, output, dimension); (-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension);
return Status::OK(); return Status::OK();
} }
@ -97,14 +97,14 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) {
std::vector<int> dimension = {classesDim}; std::vector<int> dimension = {classesDim};
NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp);
softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true);
// dEdp = softmax * sum_i(labels_i) - labels // dEdp = softmax * sum_i(labels_i) - labels
dLdp->assign(softmax * labels->reduceAlongDims(reduce::Sum, dimension, true) - *labels); dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, true) - *labels);
// dEdl = -log(softmax) // dEdl = -log(softmax)
(-softmax).applyTransform(transform::Log, dLdl); (-softmax).applyTransform(transform::Log, *dLdl);
return Status::OK(); return Status::OK();
} }

View File

@ -50,9 +50,9 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0)
std::vector<int> dimension = {-1}; std::vector<int> dimension = {-1};
auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true); auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true);
auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr); auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr);
auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log)); auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log));
helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false); helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false);
@ -117,8 +117,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
std::vector<int> dimension = {-1}; std::vector<int> dimension = {-1};
NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp); NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp);
softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true); softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true);
// dEdp = softmax - 1 (or 0) // dEdp = softmax - 1 (or 0)
dLdp->assign(softmax); dLdp->assign(softmax);

View File

@ -229,19 +229,19 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
// input - mean // input - mean
NDArray xMinusMean(input); // empty array with same shape as input NDArray xMinusMean(input); // empty array with same shape as input
input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean); input->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean);
// stdInv // stdInv
NDArray stdInv = *variance + epsilon; NDArray stdInv = *variance + epsilon;
stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon) stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon)
stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5 stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5
// dvdm (use dLdM as storage for dvdm) // dvdm (use dLdM as storage for dvdm)
xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape); xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape);
*dLdM *= -Ninv; *dLdM *= -Ninv;
// g_sum // g_sum
auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape); auto gSum = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
// dLdB // dLdB
if(applyOffset) if(applyOffset)
@ -249,11 +249,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
// stdInv * (g - g_sum/N) (use dLdI as storage for this expression) // stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
gSum *= Ninv; gSum *= Ninv;
dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI); dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, gSum, *dLdI);
dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv); dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, stdInv, *dLdI);
// dLdV <- [g*(x - m)]_sum // dLdV <- [g*(x - m)]_sum
(xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape); (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape);
// dLdG // dLdG
*dLdV *= stdInv; *dLdV *= stdInv;
@ -265,13 +265,13 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
*dLdV *= -Ninv; // -0.5f * (2 / N); *dLdV *= -Ninv; // -0.5f * (2 / N);
// dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression) // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM); xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, *dLdM, xMinusMean);
xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV); xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, *dLdV, xMinusMean);
// dLdI // dLdI
*dLdI += xMinusMean; *dLdI += xMinusMean;
if(applyScale) if(applyScale)
dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma); dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, *gamma, *dLdI);
*dLdM = 0; // put zeros so far *dLdM = 0; // put zeros so far
*dLdV = 0; // put zeros so far *dLdV = 0; // put zeros so far
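The batchnorm backward pass builds the inverse standard deviation with two in-place transforms, both of which now name their target explicitly. A compact sketch of that piece (helper name is illustrative):

    #include <NDArray.h>   // assumed libnd4j headers

    using namespace nd4j;

    NDArray inverseStd(NDArray& variance, double epsilon) {
        NDArray stdInv = variance + epsilon;                   // variance + eps
        stdInv.applyTransform(transform::Reciprocal, stdInv);  // 1 / (variance + eps)
        stdInv.applyTransform(transform::Sqrt, stdInv);        // 1 / sqrt(variance + eps)
        return stdInv;
    }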

View File

@@ -240,7 +240,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
     if(gradB) {
         if(gradB->rankOf() == 2)
             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-        gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot);      // sum over bS oD oH oW
+        gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot);     // sum over bS oD oH oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
     }

View File

@@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
     if(gradB) {
         if(gradB->rankOf() == 2)
             gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
-        gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3});      // sum over bS, oH, oW
+        gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3});     // sum over bS, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
     }

View File

@@ -244,7 +244,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
     if(gradB) {
         if(gradB->rankOf() == 2)
             gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-        gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4});      // sum over bS, oD, oH, oW
+        gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4});     // sum over bS, oD, oH, oW
         if(gradB != OUTPUT_VARIABLE(2))
             delete gradB;
     }

View File

@@ -84,7 +84,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
         epsilon = 0.001;
     const int restSize = x->lengthOf() / iD;
-    auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, x->dataType(), block.launchContext());
+    auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
     xAffected.assign(x);
     const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
@@ -93,7 +93,7 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
     const double restSizeAdjust = (double)restSize / restSizeMinusOne;
     if(isTraining) {
-        auto sum = xAffected.reduceAlongDims(reduce::Sum, {0});
+        auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
         sum *= restSizeInv;
         mean->assign(sum);
         *batchMean = *mean;
@@ -106,8 +106,8 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
     if(isTraining) {
         int power = 2;
-        xAffected.applyScalar(scalar::Pow, power);
-        auto sum = xAffected.reduceAlongDims(reduce::Sum, {0});
+        xAffected.applyScalar(scalar::Pow, power, xAffected);
+        auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
         sum *= restSizeInv;
         variance->assign(sum);
         *batchVar = (*variance) * restSizeAdjust;
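In training mode the batch moments are plain column-wise reductions over the {restSize, iD} view, now written with reduceAlongDimension. A sketch of the two steps, assuming the op has already subtracted the mean from the affected view before the variance pass (as it does between the hunks above); helper names are illustrative:

    #include <NDArray.h>   // assumed libnd4j headers

    using namespace nd4j;

    // per-channel mean of the {restSize, iD} view
    void batchMean(NDArray& xAffected, double restSizeInv, NDArray& mean) {
        auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
        sum *= restSizeInv;
        mean.assign(sum);
    }

    // per-channel variance, assuming the mean was already subtracted from xCentered
    void batchVariance(NDArray& xCentered, double restSizeInv, NDArray& variance) {
        xCentered.applyScalar(scalar::Pow, 2, xCentered);   // square in place
        auto sumSq = xCentered.reduceAlongDimension(reduce::Sum, {0});
        sumSq *= restSizeInv;
        variance.assign(sumSq);
    }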

@@ -68,7 +68,7 @@ CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) {
     helpers::softmax(block.launchContext(), *input, *gradI, dim);
-    gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true) );
+    gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) );
     return Status::OK();
 }

@@ -46,7 +46,7 @@ namespace nd4j {
     auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
     auto xw = result->at(0);
-    xw->applyScalar(nd4j::scalar::RELU, scalar, output);
+    xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
     return Status::OK();
 }

@@ -62,7 +62,7 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) {
     helpers::softmax(block.launchContext(), *input, *gradI, dim);
-    auto sumAlongDim = (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true);
+    auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true);
     gradI->assign(*gradI * (*gradO - sumAlongDim));
     return Status::OK();

@@ -56,7 +56,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {
         axes[i] = i;
     // mean as reduction for last dimension set
-    auto mean = input->reduceAlongDims(reduce::Mean, axes);
+    auto mean = input->reduceAlongDimension(reduce::Mean, axes);
     // this is contrast calculation
     output->assign((*input - mean) * (*factor) + mean);
@@ -104,13 +104,13 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
     std::vector<int> axes({1}); // dim 1 of pseudoresult
     // mean as reduction for last dimension set over size (dim 1) of result3D
-    auto mean = input3D.reduceAlongDims(reduce::Mean, axes);
+    auto mean = input3D.reduceAlongDimension(reduce::Mean, axes);
     // result as (x - mean) * factor + mean
     auto temp = input3D.ulike();
-    input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr);
-    temp.applyScalarArr(scalar::Multiply, factor);
-    temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D);
+    input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp);
+    temp.applyScalarArr(scalar::Multiply, *factor, temp);
+    temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D);
     output->assign(output3D);
     if(block.width() == 1)
         delete factor;
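The adjust_contrast_v2 hunk shows applyBroadcast and applyScalarArr after the signature change: the operand and target arrays are passed by reference and the trailing extra-arguments pointer disappears. A hedged sketch of that call shape (names, shapes, and the scalar factor are illustrative):

```cpp
// Sketch of the reference-based applyBroadcast / applyScalarArr calls.
#include <NDArray.h>
#include <NDArrayFactory.h>

using namespace nd4j;

static void contrastLikeSketch() {
    auto input3D = NDArrayFactory::create<float>('c', {2, 5, 3});  // [bS, length, channels], illustrative
    auto factor  = NDArrayFactory::create<float>(1.5f);            // scalar NDArray

    auto mean = input3D.reduceAlongDimension(reduce::Mean, {1});   // mean over the middle axis
    auto temp = input3D.ulike();                                   // empty array of same shape/type
    auto out  = input3D.ulike();

    input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp);  // x - mean
    temp.applyScalarArr(scalar::Multiply, factor, temp);              // (x - mean) * factor
    temp.applyBroadcast(broadcast::Add, {0, 2}, mean, out);           // ... + mean
}
```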

@@ -44,11 +44,11 @@ namespace nd4j {
         auto axisVector = INPUT_VARIABLE(1);
         helpers::adjustAxis(input->rankOf(), axisVector, axis);
-        input->applyIndexReduce(indexreduce::IndexMax, output, axis);
+        input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
     } else {
         helpers::adjustAxis(input->rankOf(), axis);
-        input->applyIndexReduce(indexreduce::IndexMax, output, axis);
+        input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
     }
     STORE_RESULT(output);
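Here (and in the matching argmin hunk below) the index output is handed to applyIndexReduce by reference. A minimal hedged sketch of that call, with illustrative names and shapes:

```cpp
// Sketch: applyIndexReduce writing its indices into an output passed as NDArray&.
#include <NDArray.h>
#include <NDArrayFactory.h>

using namespace nd4j;

static void rowArgMaxSketch() {
    auto input  = NDArrayFactory::create<float>('c', {3, 4});
    auto output = NDArrayFactory::create<Nd4jLong>('c', {3});      // one index per row, illustrative dtype

    std::vector<int> axis = {1};
    input.applyIndexReduce(indexreduce::IndexMax, output, axis);   // output taken by reference
}
```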

@@ -44,11 +44,11 @@ namespace nd4j {
         auto axisVector = INPUT_VARIABLE(1);
         helpers::adjustAxis(input->rankOf(), axisVector, axis);
-        input->applyIndexReduce(indexreduce::IndexMin, output, axis);
+        input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
     } else {
         helpers::adjustAxis(input->rankOf(), axis);
-        input->applyIndexReduce(indexreduce::IndexMin, output, axis);
+        input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
     }
     STORE_RESULT(output);

@@ -82,7 +82,7 @@ CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
     gradI->assign(gradO);
-    gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
+    gradO->reduceAlongDimension(nd4j::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
     return ND4J_STATUS_OK;
 }

@@ -45,7 +45,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
         v = i++;
     }
-    std::unique_ptr<ResultSet> outputView(output->allTensorsAlongDimension(dims));
+    ResultSet outputView = output->allTensorsAlongDimension(dims);
     REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.",
         output->sizeAt(0), block.width()
     );
@@ -53,7 +53,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
         Nd4jLong thisIndex = (*indeces).e<Nd4jLong>(e);
         input = INPUT_VARIABLE(thisIndex); // lookup param
-        outputView->at(e)->assign(input);
+        outputView.at(e)->assign(input);
         }
     }
     else {
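embedding_lookup now receives the per-index views as a ResultSet value rather than a heap-allocated pointer wrapped in unique_ptr. A hedged sketch of that ownership pattern (shapes and the assigned source are illustrative):

```cpp
// Sketch: allTensorsAlongDimension returning ResultSet by value, so no wrapper
// or manual delete is needed for the view collection.
#include <NDArray.h>
#include <NDArrayFactory.h>

using namespace nd4j;

static void assignRowViewsSketch() {
    auto output = NDArrayFactory::create<float>('c', {4, 3});
    auto source = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});

    ResultSet rows = output.allTensorsAlongDimension({1});   // one view per row, owned by the ResultSet
    for (int e = 0; e < rows.size(); e++)
        rows.at(e)->assign(source);                          // views are still accessed as pointers
}
```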

@@ -49,8 +49,8 @@ namespace nd4j {
     }
     std::vector<int>& dims = axis;
-    input->varianceAlongDimension(variance::SummaryStatsVariance, variances, false, axis);
-    input->reduceAlongDimension(reduce::Mean, means, axis, keepDims);
+    input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis);
+    input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims);
     return Status::OK();
 }

@@ -52,31 +52,31 @@ namespace nd4j {
         case 0: {
             REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only");
             // fro
-            input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2);
+            input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2);
         }
         break;
         case 1: {
             // euclidean
             if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) {
-                input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2);
+                input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2);
             } else {
-                input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2);
+                input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2);
             }
         }
         break;
         case 2: {
             // 1
-            input->reduceAlongDimension(reduce::Norm1, output, dims, false, output->rankOf() == 2);
+            input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2);
         }
         break;
         case 3: {
             // 2
-            input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2);
+            input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2);
         }
         break;
         case 4: {
             // inf-norm
-            input->reduceAlongDimension(reduce::NormMax, output, dims, false, output->rankOf() == 2);
+            input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2);
         }
         break;
         default: {
@@ -84,7 +84,7 @@ namespace nd4j {
             REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided");
             // FIXME: p is required here
             //T p = T_ARG(1);
-            input->reduceAlongDimension(reduce::NormP, output, dims, false, output->rankOf() == 2);
+            input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2);
         }
     }

@@ -40,23 +40,20 @@ namespace nd4j {
         shift.assign(T_ARG(0));
     }
-    means->applyScalarArr(scalar::Divide, counts, resMeans, nullptr);
-    NDArray* squareMeans = resMeans->dup('c');
-    NDArray* tempVariances = resVariances->dup('c');
-    squareMeans->applyTransform(transform::Square, squareMeans, nullptr);
-    variances->applyScalarArr(scalar::Divide, counts, tempVariances, nullptr);
-    // tempVariances->printIndexedBuffer("varianced divided by count");
-    tempVariances->applyPairwiseTransform(pairwise::Subtract, squareMeans, resVariances, nullptr);
+    means->applyScalarArr(scalar::Divide, *counts, *resMeans);
+    NDArray squareMeans = resMeans->dup('c');
+    NDArray tempVariances = resVariances->dup('c');
+    squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
+    variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
+    // tempVariances.printIndexedBuffer("varianced divided by count");
+    tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);
     if (shift.e<double>(0) != 0) {
-        resMeans->applyScalarArr(scalar::Add, &shift, resMeans, nullptr);
+        resMeans->applyScalarArr(scalar::Add, shift, *resMeans);
     }
-    delete squareMeans;
-    delete tempVariances;
     return Status::OK();
 }
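normalize_moments also picks up the NDArray::dup signature change: the duplicate now comes back by value, which is why the two delete statements drop out. A hedged sketch of the new ownership (the array contents are illustrative):

```cpp
// Sketch: value-returning NDArray::dup, no new/delete pair required.
#include <NDArray.h>
#include <NDArrayFactory.h>

using namespace nd4j;

static void squareOfMeansSketch() {
    auto resMeans = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});

    NDArray squareMeans = resMeans.dup('c');                     // a value, not an NDArray*
    squareMeans.applyTransform(transform::Square, squareMeans);  // element-wise square into itself
    // squareMeans is released automatically when it goes out of scope
}
```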

@@ -47,7 +47,7 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) {
     for(const auto& item : dimensions)
         REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-    input->reduceAlongDimension(reduce::Mean, output, dimensions, keepDims);
+    input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims);
     return Status::OK();
 }

@@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) {
     for(const auto& item : dimensions)
         REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-    input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, output, biasCorrected, dimensions);
+    input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions);
     return Status::OK();
 }
@@ -130,10 +130,10 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) {
     const Nd4jLong N = input->lengthOf() / gradO->lengthOf();
     const Nd4jLong NminusOne = biasCorrected ? N - 1 : N;
-    auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true);
+    auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true);
     NDArray variance(mean.getShapeInfo(), true, block.launchContext());    // create empty array with shape matching shape of mean array
-    input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, &variance, biasCorrected, dimensions);
+    input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions);
     gradI->assign( (*input - mean) / (variance * NminusOne));              // automatic broadcasting happens here

@@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) {
     for(const auto& item : dimensions)
         REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-    input->varianceAlongDimension(variance::SummaryStatsVariance, output, biasCorrected, dimensions);
+    input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions);
     return Status::OK();
 }
@@ -129,7 +129,7 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) {
     const double factor1 = 2.0 / NminusOne;
     const double factor2 = 2.0 / (N * NminusOne);
-    auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true);
+    auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true);
     gradI->assign( (*input - mean) * (2.0f / NminusOne));    // automatic broadcasting happens here

@@ -45,9 +45,9 @@ namespace ops {
     //void* whereMax = (void*)();
     auto internal = (*input);
     internal -= maxVals;
-    internal.applyTransform(transform::Exp, nullptr, nullptr);
-    internal.reduceAlongDimension(reduce::Sum, output, axes, keepDims, false); //, (void*)&maxVals);
-    output->applyTransform(transform::Log, nullptr, nullptr);
+    internal.applyTransform(transform::Exp, internal);
+    internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, (void*)&maxVals);
+    output->applyTransform(transform::Log, *output);
     (*output) += maxVals;
     return ND4J_STATUS_OK;
 }

@@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) {
     else if (block.getTArguments()->size() > 0)
         keepDims = (bool)T_ARG(0);
-    input->reduceAlongDimension(reduce::Max, output, dimensions, keepDims);
+    input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims);
     return Status::OK();
 }
@@ -122,8 +122,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) {
     else {
         auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMax, dimensions);
-        helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions));  // 6 corresponds to copy operation
-        delete indicesArr;
+        helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions));   // 6 corresponds to copy operation
     }
     return Status::OK();

@@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) {
     else if (block.getTArguments()->size() > 0)
         keepDims = (bool)T_ARG(0);
-    input->reduceAlongDimension(reduce::Min, output, dimensions, keepDims);
+    input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims);
     return Status::OK();
 }
@@ -125,8 +125,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) {
     else {
         auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMin, dimensions);
-        helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions));  // 6 corresponds to copy operation
-        delete indicesArr;
+        helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions));   // 6 corresponds to copy operation
     }
     return Status::OK();
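In both backprop hunks the dimension-only applyIndexReduce overload now returns the index array by value, so it can be passed straight to scatterSimple and the trailing delete goes away. A hedged sketch of that usage (names and shapes are illustrative):

```cpp
// Sketch: applyIndexReduce(op, dims) returning the index array by value.
#include <NDArray.h>
#include <NDArrayFactory.h>

using namespace nd4j;

static void rowArgMinSketch() {
    auto input = NDArrayFactory::create<float>('c', {3, 4});

    std::vector<int> dimensions = {1};
    auto indicesArr = input.applyIndexReduce(indexreduce::IndexMin, dimensions);  // NDArray by value
    indicesArr.printIndexedBuffer("per-row argmin");   // illustrative use; freed automatically
}
```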
