Shyrma temp (#131)

* - specifying template instantiation for certain types in float16 and bfloat16

Signed-off-by: Yurii <iuriish@yahoo.com>

* - polishing bfloat16 and float16 member function template specializations

Signed-off-by: Yurii <iuriish@yahoo.com>

* - rewrite and overload array +-*/ scalar and scalar +-*/ array in NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>
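
The overload set this bullet refers to roughly looks as follows; this is an illustrative sketch only, and the exact template constraints and parameter spellings in the commit may differ:

// Sketch of the NDArray <op> scalar / scalar <op> NDArray overloads, constrained to
// supported scalar types and including rvalue overloads so temporaries can be reused.
template <typename T, typename = typename std::enable_if<scalarTypesForNDarray<T>::value>::type>
ND4J_EXPORT NDArray operator+(const NDArray& arr, const T& scalar);   // allocates a new result
template <typename T, typename = typename std::enable_if<scalarTypesForNDarray<T>::value>::type>
ND4J_EXPORT NDArray operator+(NDArray&& arr, const T& scalar);        // may recycle arr's buffer
template <typename T, typename = typename std::enable_if<scalarTypesForNDarray<T>::value>::type>
ND4J_EXPORT NDArray operator+(const T& scalar, const NDArray& arr);   // scalar on the left
// analogous overloads exist for operator-, operator*, operator/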

* - make corrections related to rvalue/lvalue conversions

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide move semantics in NDArray operators array +-*/ array

Signed-off-by: Yurii <iuriish@yahoo.com>
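
A minimal sketch of the idea (not the literal implementation): when one operand is a temporary, the binary operator can compute the result in place and hand that operand's buffer to the result instead of allocating a new array.

// Illustrative only: the lvalue overload allocates, the rvalue overload reuses the temporary.
ND4J_EXPORT NDArray operator+(const NDArray& x, const NDArray& y);   // allocates a result
ND4J_EXPORT NDArray operator+(NDArray&& x, const NDArray& y) {       // x is a temporary
    x += y;                  // accumulate into the buffer we are about to hand over
    return std::move(x);     // the result steals x's buffer, no extra allocation
}
// usage: the temporaries produced by (a + b) and ((a + b) + c) are recycled
// NDArray d = a + b + c;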

* float16/bfloat16 tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* one more tweak

Signed-off-by: raver119 <raver119@gmail.com>

* - make float16 and bfloat16 compile successfully on cuda

Signed-off-by: Yurii <iuriish@yahoo.com>

* - do not use the resources of view-like arrays when move semantics are applied

Signed-off-by: Yurii <iuriish@yahoo.com>
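
The guard described here can be pictured like this (a hedged sketch, not the exact code path): an rvalue overload only recycles the incoming buffer when the array fully owns it, so views over someone else's storage are never modified or stolen.

NDArray operator+(NDArray&& arr, const double scalar) {
    if (!arr.isView()) {            // a true temporary: safe to update and reuse its buffer
        arr += scalar;
        return std::move(arr);
    }
    NDArray result = arr.dup();     // view-like input: copy first, leave its storage intact
    result += scalar;
    return result;
}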

* - get rid of pointers in signatures of NDArray methods 1

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correction of signature of NDArray::dup method

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correction of signature of NDArray::reduceAlongDimension method

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyIndexReduce and applyTrueBroadcast methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyReduce3 and varianceAlongDimension methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::tensorsAlongDimension and diagonal methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::allTensorsAlongDimension

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::reduceAlongDimension 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyTransform 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyPairwiseTransform 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyBroadcast 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyTrueBroadcast 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::applyScalar and applyScalarArr

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::lambda methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::reduce3 methods 2

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of the following NDArray methods: add/sub/mul/div row/column and fillAsTriangular

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::tileToShape methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* - signature correction of NDArray::isShapeSameStrict method

Signed-off-by: Yurii <iuriish@yahoo.com>

* minor corrections in tests

Signed-off-by: Yurii <iuriish@yahoo.com>

* - replace reduce op in batchnorm mkldnn

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add explicit template instantiations for operator+(NDArray&&, const scalar)

Signed-off-by: Yurii <iuriish@yahoo.com>
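
For reference, an explicit instantiation of such an operator looks roughly like this; the scalar types and the exact parameter form instantiated by the commit are not listed here, so treat these as placeholders:

template ND4J_EXPORT NDArray operator+<double>(NDArray&& arr, const double& scalar);
template ND4J_EXPORT NDArray operator+<float>(NDArray&& arr, const float& scalar);
template ND4J_EXPORT NDArray operator+<int>(NDArray&& arr, const int& scalar);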

* - corrections of casts in float16/bfloat16

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide move semantics in the following NDArray methods: transform, applyTrueBroadcast, transpose, reshape, permute

Signed-off-by: Yurii <iuriish@yahoo.com>
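
In practice this means chained calls on temporaries no longer copy at every step; a hedged usage sketch (method names from the commit, shapes illustrative):

auto x = NDArrayFactory::create<float>('c', {4, 4});
// reshape/permute/transpose return NDArray by value; when called on a temporary,
// the intermediate result can pass its buffer forward instead of being copied
auto y = x.reshape('c', {2, 8}).permute({1, 0});
auto t = x.transpose();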

* - get rid of the duplicate of input array A in svd cuda op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - work around a known bug in the svd cuda API

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add temporary global memory buffer in svd cuda when calcUV = false and m != n

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove test with bfloat16 type for betainc

Signed-off-by: Yurii <iuriish@yahoo.com>

* - resolve conflicts after master has been merged in

Signed-off-by: Yurii <iuriish@yahoo.com>

* - changed type of affected input array in fused_batch_norm

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add several explicit type castings

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add ND4J_EXPORT to operators

Signed-off-by: Yurii <iuriish@yahoo.com>

* - add explicit template types in instantiations of template arithmetic operators of NDArray class

Signed-off-by: Yurii <iuriish@yahoo.com>

* - one more test fix

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
Branch: master
Author: Yurii Shyrma, 2019-12-20 21:35:39 +02:00, committed by raver119
Parent: 3e0afadea1
Commit: 5d9b2a16e5
237 changed files with 5235 additions and 6513 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -104,7 +104,7 @@ namespace graph {
if (node->id() == 13)
nd4j_debug("","");
// if true - this is special case: Graph-in-Graph.
if (node->hasGraphEmbedded()) {
auto embedded = node->getGraph();
@@ -128,12 +128,12 @@ namespace graph {
int cnt = 0;
for (Variable* v: *embedded->getPlaceholders()) {
if (v->getName() != nullptr && v->getName()->size() > 0) {
// trying symbolic lookup first
if (variableSpace->hasVariable(v->getName())) {
// symbolic feeder
auto array = variableSpace->getVariable(v->getName())->getNDArray();
- auto vr = array->dup();
+ auto vr = new NDArray(array->dup());
// deletables.push_back(vr);
v->setNDArray(vr);
} else {
@@ -145,7 +145,7 @@ namespace graph {
// if we're not using symbolic lookup - we'll use sequential approach then
auto p = node->input()->at(cnt);
auto array = variableSpace->getVariable(p)->getNDArray();
- auto vr = array->dup();
+ auto vr = new NDArray(array->dup());
//deletables.push_back(vr);
v->setNDArray(vr);
}
@@ -501,7 +501,7 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
}
/**
* This method is provided for IPC:
* 1) it accepts pointer to FlatBuffers buffer
* 2) restores Graph from it
* 3) Executes this Graph
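
The change above reflects the new dup() signature: it now returns an NDArray by value rather than a raw pointer, so call sites that still need a heap object wrap the result explicitly. A minimal before/after sketch:

// old: auto vr = array->dup();            // returned NDArray*, caller had to delete it
// new:
NDArray* vr = new NDArray(array->dup());   // heap copy, as the executioner does above
NDArray local = array->dup();              // or simply keep the copy by value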


@@ -71,44 +71,41 @@ void NDArray::makeBothBuffersActual() const { }
////////////////////////////////////////////////////////////////////////
template <typename T>
- void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) {
+ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) {
if (isS())
throw std::runtime_error("NDArray::fillArrayAsTriangular: you can't use this method on String array!");
- if(target == nullptr)
- target = this;
- if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1)))
+ if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1)))
throw std::string("NDArray::fillArrayAsTriangular method: wrong shape of target array !");
if (direction == 'u')
- lower = -target->sizeAt(-2);
+ lower = -target.sizeAt(-2);
else if (direction == 'l')
- upper = target->sizeAt(-1);
+ upper = target.sizeAt(-1);
const T value = static_cast<T>(val);
const auto x = reinterpret_cast<const T*>(getBuffer());
- auto z = reinterpret_cast<T*>(target->getBuffer());
+ auto z = reinterpret_cast<T*>(target.getBuffer());
const int xRank = rankOf();
- const int zRank = target->rankOf();
+ const int zRank = target.rankOf();
- const auto zLen = target->lengthOf();
+ const auto zLen = target.lengthOf();
- const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target->getShapeInfo());
+ const bool areSameOffsets = shape::haveSameShapeAndStrides(getShapeInfo(), target.getShapeInfo());
auto func = PRAGMA_THREADS_FOR {
Nd4jLong coords[MAX_RANK];
for (auto i = start; i < stop; i += increment) {
- shape::index2coords(i, target->getShapeInfo(), coords);
+ shape::index2coords(i, target.getShapeInfo(), coords);
- const auto zOffset = shape::getOffset(target->getShapeInfo(), coords);
+ const auto zOffset = shape::getOffset(target.getShapeInfo(), coords);
// if( (row + upper < col) || (row + lower > col) )
if ((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
z[zOffset] = value;
- else if (this != target) { // when this and target are different arrays
+ else if (this != &target) { // when this and target are different arrays
if (xRank != zRank)
coords[0] = coords[1];
@@ -120,7 +117,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
samediff::Threads::parallel_for(func, 0, zLen);
}
- BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
+ BUILD_SINGLE_TEMPLATE(template void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
void NDArray::setIdentity() {
@@ -405,11 +402,11 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
//////////////////////////////////////////////////////////////////////////
// create new array by repeating it the number of times given by repeats
- NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
+ NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
- auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
+ NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
- BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, *output, repeats, axis), LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, output, repeats, axis), LIBND4J_TYPES);
return output;
}
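
A short usage sketch of the updated signatures above (assuming the usual libnd4j headers and NDArrayFactory; exact include paths and default arguments may differ):

auto x = NDArrayFactory::create<float>('c', {3, 3});
auto z = NDArrayFactory::create<float>('c', {3, 3});
x.fillAsTriangular<float>(0.f, 0, 0, z, 'u');   // target is now passed by reference
NDArray r = x.repeat(0, {2});                   // repeat now returns by value, no raw pointer to delete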


@@ -2,35 +2,24 @@
template<typename T>
- void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<T(T, T, T)>& func, NDArray* target) {
+ void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, const std::function<T(T, T, T)>& func, NDArray& target) {
- if (target == nullptr)
- target = this;
- if (second == nullptr) {
- nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Second is NULL\n","");
- throw std::runtime_error("second is null");
- }
- if (third == nullptr) {
- nd4j_printf("applyTriplewiseLambda requires three operands to be valid NDArrays, but Third is NULL\n","");
- throw std::runtime_error("third is null");
- }
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
- if(dataType() != second->dataType() || dataType() != third->dataType() || dataType() != target->dataType())
+ if(dataType() != second.dataType() || dataType() != third.dataType() || dataType() != target.dataType())
throw std::runtime_error("NDArray::applyTriplewiseLambda<T> method: bother four arrays (this, second, third, target) should have the same type !");
- if (this->lengthOf() != second->lengthOf() || this->lengthOf() != third->lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
+ if (this->lengthOf() != second.lengthOf() || this->lengthOf() != third.lengthOf() || !this->isSameShape(second) || !this->isSameShape(third)) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
- auto s = second->bufferAsT<T>();
+ auto s = second.bufferAsT<T>();
- auto t = third->bufferAsT<T>();
+ auto t = third.bufferAsT<T>();
- auto z = target->bufferAsT<T>();
+ auto z = target.bufferAsT<T>();
- if (this->ordering() == second->ordering() && this->ordering() == third->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == second->ews() && this->ews() == third->ews()) {
+ if (this->ordering() == second.ordering() && this->ordering() == third.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == second.ews() && this->ews() == third.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
@@ -44,8 +33,8 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto tOffset = this->getOffset(e);
- auto uOffset = second->getOffset(e);
+ auto uOffset = second.getOffset(e);
- auto vOffset = third->getOffset(e);
+ auto vOffset = third.getOffset(e);
f[tOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
@@ -57,9 +46,9 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto tOffset = this->getOffset(e);
- auto uOffset = second->getOffset(e);
+ auto uOffset = second.getOffset(e);
- auto vOffset = third->getOffset(e);
+ auto vOffset = third.getOffset(e);
- auto zOffset = target->getOffset(e);
+ auto zOffset = target.getOffset(e);
z[zOffset] = func(f[tOffset], s[uOffset], t[vOffset]);
}
@@ -69,46 +58,39 @@ void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::
}
}
}
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<double (double, double, double)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<double (double, double, double)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float (float, float, float)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<float (float, float, float)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<float16 (float16, float16, float16)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<float16 (float16, float16, float16)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<bfloat16 (bfloat16, bfloat16, bfloat16)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int (int, int, int)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int (int, int, int)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int16_t (int16_t, int16_t, int16_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint8_t (uint8_t, uint8_t, uint8_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint16_t (uint16_t, uint16_t, uint16_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint32_t (uint32_t, uint32_t, uint32_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<uint64_t (uint64_t, uint64_t, uint64_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<int8_t (int8_t, int8_t, int8_t)>& func, NDArray& target);
- template void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, const std::function<bool (bool, bool, bool)>& func, NDArray* target);
+ template void NDArray::applyTriplewiseLambda(NDArray& second, NDArray &third, const std::function<bool (bool, bool, bool)>& func, NDArray& target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
- void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T, T)>& func, NDArray* target) {
+ void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<T(T, T)>& func, NDArray& target) {
- if (target == nullptr)
- target = this;
- if (other == nullptr) {
- nd4j_printf("applyPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
- throw std::runtime_error("Other is null");
- }
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
- if(dataType() != other->dataType() || dataType() != target->dataType())
+ if(dataType() != other.dataType() || dataType() != target.dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda<T> method: all three arrays (this, other, target) must have the same type !");
- if (this->lengthOf() != other->lengthOf()) {
+ if (this->lengthOf() != other.lengthOf()) {
nd4j_printf("applyPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
- auto s = other->bufferAsT<T>();
+ auto s = other.bufferAsT<T>();
- auto z = target->bufferAsT<T>();
+ auto z = target.bufferAsT<T>();
- if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
+ if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
@@ -122,7 +104,7 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto yOffset = other->getOffset(e);
+ auto yOffset = other.getOffset(e);
f[xOffset] = func(f[xOffset], s[yOffset]);
}
@@ -134,8 +116,8 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto yOffset = other->getOffset(e);
+ auto yOffset = other.getOffset(e);
- auto zOffset = target->getOffset(e);
+ auto zOffset = target.getOffset(e);
z[zOffset] = func(f[xOffset], s[yOffset]);
}
@@ -145,35 +127,33 @@ void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<T(T,
}
}
}
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<double (double, double)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<double (double, double)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float (float, float)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<float (float, float)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<float16 (float16, float16)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<float16 (float16, float16)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<bfloat16 (bfloat16, bfloat16)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int (int, int)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int (int, int)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int16_t (int16_t, int16_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint8_t (uint8_t, uint8_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint16_t (uint16_t, uint16_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint32_t (uint32_t, uint32_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<uint64_t (uint64_t, uint64_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<int8_t (int8_t, int8_t)>& func, NDArray& target);
- template void NDArray::applyPairwiseLambda(const NDArray* other, const std::function<bool (bool, bool)>& func, NDArray* target);
+ template void NDArray::applyPairwiseLambda(const NDArray& other, const std::function<bool (bool, bool)>& func, NDArray& target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
- void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
+ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray& target) {
- if (target == nullptr)
- target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
- if(dataType() != target->dataType())
+ if(dataType() != target.dataType())
throw std::runtime_error("NDArray::applyLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
- auto z = target->bufferAsT<T>();
+ auto z = target.bufferAsT<T>();
- if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
+ if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
@@ -198,7 +178,7 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto zOffset = target->getOffset(e);
+ auto zOffset = target.getOffset(e);
z[zOffset] = func(f[xOffset]);
}
@@ -208,35 +188,33 @@ void NDArray::applyLambda(const std::function<T(T)>& func, NDArray* target) {
}
}
}
- template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<double(double)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<float(float)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<float16(float16)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<bfloat16(bfloat16)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<Nd4jLong(Nd4jLong)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<int16_t(int16_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<int32_t(int32_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<uint8_t(uint8_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<uint16_t(uint16_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<uint32_t(uint32_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<uint64_t(uint64_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<int8_t(int8_t)>& func, NDArray& target);
- template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray* target);
+ template void NDArray::applyLambda(const std::function<bool(bool)>& func, NDArray& target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
- void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray* target) {
+ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDArray& target) {
- if (target == nullptr)
- target = this;
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
- if(dataType() != target->dataType())
+ if(dataType() != target.dataType())
throw std::runtime_error("NDArray::applyIndexedLambda<T> method: types of this and target array should match !");
auto f = this->bufferAsT<T>();
- auto z = target->bufferAsT<T>();
+ auto z = target.bufferAsT<T>();
- if (this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1)) {
+ if (this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1)) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
@@ -261,7 +239,7 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto zOffset = target->getOffset(e);
+ auto zOffset = target.getOffset(e);
z[zOffset] = func(e, f[xOffset]);
}
@@ -271,44 +249,38 @@ void NDArray::applyIndexedLambda(const std::function<T(Nd4jLong, T)>& func, NDAr
}
}
}
- template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<double(Nd4jLong, double)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<float(Nd4jLong, float)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<float16(Nd4jLong, float16)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<bfloat16(Nd4jLong, bfloat16)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<Nd4jLong(Nd4jLong, Nd4jLong)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<int(Nd4jLong, int)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<int16_t(Nd4jLong, int16_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<uint8_t (Nd4jLong, uint8_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<uint16_t (Nd4jLong, uint16_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<uint32_t (Nd4jLong, uint32_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<uint64_t (Nd4jLong, uint64_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<int8_t(Nd4jLong, int8_t)>& func, NDArray& target);
- template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray* target);
+ template void NDArray::applyIndexedLambda(const std::function<bool(Nd4jLong, bool)>& func, NDArray& target);
//////////////////////////////////////////////////////////////////////////
template<typename T>
- void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(Nd4jLong, T, T)>& func, NDArray* target) {
+ void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<T(Nd4jLong, T, T)>& func, NDArray& target) {
- if (target == nullptr)
- target = this;
- if (other == nullptr) {
- nd4j_printf("applyIndexedPairwiseLambda requires both operands to be valid NDArrays, but Y is NULL\n","");
- throw std::runtime_error("Other is null");
- }
if(dataType() != DataTypeUtils::fromT<T>())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: wrong template parameter T, its type should be the same as type of this array!");
- if(dataType() != target->dataType())
+ if(dataType() != target.dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda<T> method: types of this and target array should match !");
- if (this->lengthOf() != other->lengthOf()) {
+ if (this->lengthOf() != other.lengthOf()) {
nd4j_printf("applyIndexedPairwiseLambda requires both operands to have the same shape\n","");
throw std::runtime_error("Shapes mismach");
}
auto f = this->bufferAsT<T>();
- auto s = other->bufferAsT<T>();
+ auto s = other.bufferAsT<T>();
- auto z = target->bufferAsT<T>();
+ auto z = target.bufferAsT<T>();
- if (this->ordering() == other->ordering() && this->ordering() == target->ordering() && (this->ews() == 1 && target->ews() == 1) && this->ews() == other->ews()) {
+ if (this->ordering() == other.ordering() && this->ordering() == target.ordering() && (this->ews() == 1 && target.ews() == 1) && this->ews() == other.ews()) {
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment)
@@ -322,7 +294,7 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto yOffset = other->getOffset(e);
+ auto yOffset = other.getOffset(e);
f[xOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
@@ -334,8 +306,8 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
auto loop = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) {
auto xOffset = this->getOffset(e);
- auto yOffset = other->getOffset(e);
+ auto yOffset = other.getOffset(e);
- auto zOffset = target->getOffset(e);
+ auto zOffset = target.getOffset(e);
z[zOffset] = func((Nd4jLong) e, f[xOffset], s[yOffset]);
}
@@ -345,16 +317,16 @@ void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<T(N
}
}
}
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<double (Nd4jLong, double, double)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<double (Nd4jLong, double, double)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float (Nd4jLong, float, float)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<float (Nd4jLong, float, float)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<float16 (Nd4jLong, float16, float16)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<bfloat16 (Nd4jLong, bfloat16, bfloat16)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<Nd4jLong (Nd4jLong, Nd4jLong, Nd4jLong)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int (Nd4jLong, int, int)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int (Nd4jLong, int, int)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int16_t (Nd4jLong, int16_t, int16_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint8_t (Nd4jLong, uint8_t, uint8_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint16_t (Nd4jLong, uint16_t, uint16_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint32_t (Nd4jLong, uint32_t, uint32_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<uint64_t (Nd4jLong, uint64_t, uint64_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<int8_t (Nd4jLong, int8_t, int8_t)>& func, NDArray& target);
- template void NDArray::applyIndexedPairwiseLambda(NDArray* other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray* target);
+ template void NDArray::applyIndexedPairwiseLambda(NDArray& other, const std::function<bool (Nd4jLong, bool, bool)>& func, NDArray& target);
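
Usage of the reference-based lambda helpers now looks like this (a hedged sketch; the explicit <float> selects the std::function instantiations shown above, and NDArrayFactory is assumed to be available):

auto x = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
auto y = NDArrayFactory::create<float>('c', {2, 2}, {10.f, 20.f, 30.f, 40.f});
auto z = NDArrayFactory::create<float>('c', {2, 2});
x.applyPairwiseLambda<float>(y, [](float a, float b) { return a * b; }, z);   // z = x * y elementwise
x.applyLambda<float>([](float v) { return v + 1.f; }, x);                      // in-place now means passing x itself as target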


@@ -2717,25 +2717,25 @@ static void _scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSub
switch (opCode) {
case 0:
- inSubArr.applyPairwiseTransform(pairwise::Add, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr);
break;
case 1:
- inSubArr.applyPairwiseTransform(pairwise::Subtract, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::Subtract, updSubArr, inSubArr);
break;
case 2:
- inSubArr.applyPairwiseTransform(pairwise::Multiply, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, inSubArr);
break;
case 3:
- inSubArr.applyPairwiseTransform(pairwise::Divide, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::Divide, updSubArr, inSubArr);
break;
case 4:
- inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::ReverseSubtract, updSubArr, inSubArr);
break;
case 5:
- inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::ReverseDivide, updSubArr, inSubArr);
break;
case 6:
- inSubArr.applyPairwiseTransform(pairwise::CopyPws, &updSubArr, &inSubArr, nullptr);
+ inSubArr.applyPairwiseTransform(pairwise::CopyPws, updSubArr, inSubArr);
break;
default:
continue;
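
The same call shape applies outside scatterUpdate as well; a hedged sketch of in-place versus out-of-place use of the new signature:

// in-place: the array itself is the target
inSubArr.applyPairwiseTransform(pairwise::Add, updSubArr, inSubArr);
// out-of-place: write into a separate, same-shaped result
NDArray out(inSubArr);   // illustrative: a copy used only to get a matching shape/type
inSubArr.applyPairwiseTransform(pairwise::Multiply, updSubArr, out);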


@@ -122,35 +122,32 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha
///////////////////////////////////////////////////////////////////
template<typename T>
- void NDArray::fillAsTriangular(const float val, int lower, int upper, const char direction, NDArray* target) {
+ void NDArray::fillAsTriangular(const float val, int lower, int upper, NDArray& target, const char direction) {
if (isS())
throw std::runtime_error("NDArray::fillAsTriangular: you can't use this method on String array!");
- if(target == nullptr)
- target = this;
- if(!isSameShape(target) && !(rankOf() == 1 && target->rankOf() == 2 && sizeAt(0) == target->sizeAt(0) && sizeAt(0) == target->sizeAt(1)))
+ if(!isSameShape(target) && !(rankOf() == 1 && target.rankOf() == 2 && sizeAt(0) == target.sizeAt(0) && sizeAt(0) == target.sizeAt(1)))
throw std::string("NDArray::fillAsTriangular method: wrong shape of target array !");
if (direction == 'u')
- lower = -target->sizeAt(-2);
+ lower = -target.sizeAt(-2);
else if (direction == 'l')
- upper = target->sizeAt(-1);
+ upper = target.sizeAt(-1);
const int threadsPerBlock = MAX_NUM_THREADS / 4;
- const int blocksPerGrid = (target->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+ const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
- const int sharedMem = threadsPerBlock * sizeof(decltype(*target->getShapeInfo())) * target->rankOf() + 128;
+ const int sharedMem = threadsPerBlock * sizeof(decltype(*target.getShapeInfo())) * target.rankOf() + 128;
PointersManager manager(getContext(), "NDArray::fillAsTriangular");
- NDArray::prepareSpecialUse({target}, {this});
+ NDArray::prepareSpecialUse({&target}, {this});
- fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target->getPlatformBuffer(), target->getPlatformShapeInfo(), static_cast<T>(val), lower, upper);
+ fillAsTriangularCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *getContext()->getCudaStream()>>>(getPlatformBuffer(), getPlatformShapeInfo(), target.getPlatformBuffer(), target.getPlatformShapeInfo(), static_cast<T>(val), lower, upper);
- NDArray::registerSpecialUse({target}, {this});
+ NDArray::registerSpecialUse({&target}, {this});
manager.synchronize();
}
- BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, const char direction, NDArray* target), LIBND4J_TYPES);
+ BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::fillAsTriangular, (const float val, int lower, int upper, NDArray& target, const char direction), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
template<typename T>
@@ -457,21 +454,21 @@ BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid
//////////////////////////////////////////////////////////////////////////
// create new array by repeating it the number of times given by repeats
- NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
+ NDArray NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
- auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
+ NDArray output('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
const int threadsPerBlock = MAX_NUM_THREADS / 2;
- const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+ const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
- const int sharedMem = output->rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
+ const int sharedMem = output.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
- prepareSpecialUse({output}, {this});
+ prepareSpecialUse({&output}, {this});
- BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
+ BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
- prepareSpecialUse({output}, {this});
+ prepareSpecialUse({&output}, {this});
manager.synchronize();


@ -247,73 +247,73 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename Lambda> template<typename Lambda>
void NDArray::applyLambda(Lambda func, NDArray* target) { void NDArray::applyLambda(Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType()) if (dtype != target.dataType())
throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same"); throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same");
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({result}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this}); registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename Lambda> template<typename Lambda>
void NDArray::applyPairwiseLambda(const NDArray* other, Lambda func, NDArray* target) { void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != other->dataType()) if (dtype != target.dataType() || dtype != other.dataType())
throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same");
//throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, result->dataType()); //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());
prepareSpecialUse({result}, {this, other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, other}); registerSpecialUse({&target}, {this, &other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyIndexedLambda(Lambda func, NDArray* target) { void NDArray::applyIndexedLambda(Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType()) if (dtype != target.dataType())
throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same"); throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same");
prepareSpecialUse({result}, {this}); prepareSpecialUse({&target}, {this});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this}); registerSpecialUse({&target}, {this});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyIndexedPairwiseLambda(NDArray* other, Lambda func, NDArray* target) { void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != other->dataType()) if (dtype != target.dataType() || dtype != other.dataType())
throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({result}, {this, other}); prepareSpecialUse({&target}, {this, &other});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, other}); registerSpecialUse({&target}, {this, &other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename Lambda> template <typename Lambda>
void NDArray::applyTriplewiseLambda(NDArray* second, NDArray *third, Lambda func, NDArray* target) { void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) {
auto result = target == nullptr ? this : target;
auto dtype = this->dataType(); auto dtype = this->dataType();
if (dtype != result->dataType() || dtype != second->dataType() || dtype != third->dataType()) if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType())
throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same"); throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same");
prepareSpecialUse({result}, {this, second, third}); prepareSpecialUse({&target}, {this, &second, &third});
BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second->specialBuffer(), second->specialShapeInfo(), third->specialBuffer(), third->specialShapeInfo(), result->specialBuffer(), result->specialShapeInfo(), func), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(dtype, LambdaHelper ,::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
registerSpecialUse({result}, {this, second, third}); registerSpecialUse({&target}, {this, &second, &third});
} }
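
The four hunks above migrate the applyLambda family from NDArray* parameters to NDArray& references, which is why the `auto result = target == nullptr ? this : target;` fallback disappears: a reference target always names a destination, so the in-place case is expressed by passing the array itself. A minimal standalone sketch of the same pattern (toy Array type, hypothetical names, not the libnd4j API):

```cpp
// Illustrative sketch only (toy Array type, not the libnd4j API).
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

struct Array {
    std::vector<float> data;

    // Reference-based signature: the caller always names a destination,
    // so there is no null check and no implicit "write into *this".
    template <typename F>
    void applyPairwise(const Array& other, F func, Array& target) const {
        if (data.size() != other.data.size() || data.size() != target.data.size())
            throw std::runtime_error("applyPairwise: X/Y/Z lengths must match");
        for (std::size_t i = 0; i < data.size(); ++i)
            target.data[i] = func(data[i], other.data[i]);
    }
};

int main() {
    Array x{{1.f, 2.f, 3.f}}, y{{10.f, 20.f, 30.f}}, z{{0.f, 0.f, 0.f}};
    x.applyPairwise(y, [](float a, float b) { return a + b; }, z);  // out-of-place
    x.applyPairwise(y, [](float a, float b) { return a * b; }, x);  // in-place: pass x itself
    std::cout << z.data[2] << " " << x.data[2] << "\n";             // prints "33 90"
    return 0;
}
```

Passing the array itself as the target keeps the in-place case explicit at the call site, which is the main readability gain of the reference signatures.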


@ -91,6 +91,10 @@ namespace nd4j {
template <typename T> template <typename T>
FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo); FORCEINLINE static bool castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo);
template<typename T>
// struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value || std::is_same<long long, T>::value; };
struct scalarTypesForNDarray { static bool const value = std::is_same<double, T>::value || std::is_same<float, T>::value || std::is_same<int, T>::value || std::is_same<unsigned int, T>::value || std::is_same<long long, T>::value || std::is_same<unsigned long long, T>::value || std::is_same<long int, T>::value || std::is_same<long unsigned int, T>::value || std::is_same<int8_t, T>::value || std::is_same<uint8_t, T>::value || std::is_same<int16_t, T>::value || std::is_same<uint16_t, T>::value || std::is_same<bool, T>::value || std::is_same<bfloat16, T>::value || std::is_same<float16, T>::value; };
}; };
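
The widened scalarTypesForNDarray trait above now whitelists every integer width plus bool, float16 and bfloat16. Such a trait is typically consumed through std::enable_if so that the scalar overloads of the array operators only participate in overload resolution for those types; the sketch below shows the mechanism with a reduced whitelist and a hypothetical toy Array, not the real operator set.

```cpp
// Hypothetical reduced version of the trait and one enable_if-guarded overload.
#include <iostream>
#include <type_traits>

template <typename T>
struct is_scalar_for_array {
    static bool const value = std::is_same<double, T>::value ||
                              std::is_same<float, T>::value  ||
                              std::is_same<int, T>::value;
};

struct Array { float value; };

template <typename T, typename std::enable_if<is_scalar_for_array<T>::value, int>::type = 0>
Array operator+(const Array& a, const T scalar) {
    return Array{a.value + static_cast<float>(scalar)};
}

int main() {
    Array a{1.5f};
    Array b = a + 2;       // OK: int is whitelisted
    Array c = a + 0.25;    // OK: double is whitelisted
    // Array d = a + "x";  // would not compile: const char* is not in the trait
    std::cout << b.value << " " << c.value << "\n";  // prints "3.5 1.75"
    return 0;
}
```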


@ -44,7 +44,7 @@ namespace nd4j {
} }
NDArray* NDArrayList::read(int idx) { NDArray* NDArrayList::read(int idx) {
return readRaw(idx)->dup(); return new NDArray(readRaw(idx)->dup());
} }
nd4j::DataType NDArrayList::dataType() { nd4j::DataType NDArrayList::dataType() {
@ -114,7 +114,7 @@ namespace nd4j {
} else } else
return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions"); return Status::CODE(ND4J_STATUS_BAD_INPUT, "NDArrayList: all arrays must have same size along inner dimensions");
} }
//_elements++; //_elements++;
// storing reference // storing reference
@ -136,11 +136,10 @@ namespace nd4j {
std::vector<int> args({axis}); std::vector<int> args({axis});
auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args);
auto result = array->allTensorsAlongDimension(newAxis); auto result = array->allTensorsAlongDimension(newAxis);
for (int e = 0; e < result->size(); e++) { for (int e = 0; e < result.size(); e++) {
auto chunk = result->at(e);//->dup(array->ordering()); auto chunk = result.at(e);//->dup(array->ordering());
write(e, chunk->dup(array->ordering())); write(e, new NDArray(chunk->dup(array->ordering())));
} }
delete result;
} }
NDArray* NDArrayList::stack() { NDArray* NDArrayList::stack() {
@ -161,7 +160,7 @@ namespace nd4j {
auto result = op.execute(inputs, {}, {}, {}); auto result = op.execute(inputs, {}, {}, {});
auto array = result->at(0)->dup(); auto array = new NDArray(result->at(0)->dup());
delete result; delete result;
@ -214,13 +213,11 @@ namespace nd4j {
auto tads = array->allTensorsAlongDimension(axis); auto tads = array->allTensorsAlongDimension(axis);
int indicesSize = indices.size(); int indicesSize = indices.size();
if (tads->size() != indicesSize) if (tads.size() != indicesSize)
throw std::runtime_error("Number of TADs should match number of indices"); throw std::runtime_error("Number of TADs should match number of indices");
for (int e = 0; e < indicesSize; e++) for (int e = 0; e < indicesSize; e++)
tads->at(e)->assign(_chunks[indices[e]]); tads.at(e)->assign(_chunks[indices[e]]);
delete tads;
return array; return array;
} }
@ -234,7 +231,7 @@ namespace nd4j {
list->_elements.store(_elements.load()); list->_elements.store(_elements.load());
for (auto const& v : _chunks) { for (auto const& v : _chunks) {
list->_chunks[v.first] = v.second->dup(); list->_chunks[v.first] = new NDArray(v.second->dup());
} }
return list; return list;
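
In the NDArrayList hunks above, NDArray::dup() now returns a copy by value and allTensorsAlongDimension() returns its result set by value, so the trailing delete calls disappear; call sites that still keep heap pointers (the chunk map, read()) promote the copy with new NDArray(...). A toy sketch of that call-site pattern (plain structs; unique_ptr is used here only to keep the example leak-free, the diff stores raw pointers):

```cpp
// Toy sketch of the call-site pattern; not the real NDArrayList/NDArray classes.
#include <iostream>
#include <map>
#include <memory>
#include <vector>

struct Array {
    std::vector<float> data;
    Array dup() const { return Array{data}; }   // duplicate returned by value
};

int main() {
    Array source{{1.f, 2.f, 3.f}};

    // Value-returning dup: nothing to delete for plain temporaries.
    Array copy = source.dup();

    // Where storage still holds pointers (as the chunk map does in the diff),
    // the copy is promoted to the heap explicitly.
    std::map<int, std::unique_ptr<Array>> chunks;
    chunks[0] = std::unique_ptr<Array>(new Array(source.dup()));

    std::cout << copy.data.size() << " " << chunks[0]->data.size() << "\n";  // prints "3 3"
    return 0;
}
```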


@ -48,7 +48,7 @@ namespace nd4j {
} else { } else {
// FIXME: in some cases it's possible to have no NDArray // FIXME: in some cases it's possible to have no NDArray
if (inputVar->hasNDArray()) if (inputVar->hasNDArray())
innerVar->setNDArray(inputVar->getNDArray()->dup()); innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup()));
} }
} }


@ -56,7 +56,7 @@ namespace nd4j {
} else { } else {
// FIXME: in some cases it's possible to have no NDArray // FIXME: in some cases it's possible to have no NDArray
if (inputVar->hasNDArray()) if (inputVar->hasNDArray())
innerVar->setNDArray(inputVar->getNDArray()->dup()); innerVar->setNDArray(new NDArray(inputVar->getNDArray()->dup()));
} }
} }


@ -40,7 +40,7 @@ namespace nd4j {
result->setIndex(this->_index); result->setIndex(this->_index);
if (this->_ndarray != nullptr) if (this->_ndarray != nullptr)
result->setNDArray(this->_ndarray->template asT<N>()); result->setNDArray(new NDArray(this->_ndarray->template asT<N>()));
// FIXME: add support for ArrayList // FIXME: add support for ArrayList
if (this->_list != nullptr) { if (this->_list != nullptr) {
@ -61,7 +61,7 @@ namespace nd4j {
result->_index = this->_index; result->_index = this->_index;
if (this->_ndarray != nullptr) if (this->_ndarray != nullptr)
result->_ndarray = this->_ndarray->dup(this->_ndarray->ordering()); result->_ndarray = new NDArray(this->_ndarray->dup(this->_ndarray->ordering()));
if (this->_list != nullptr) if (this->_list != nullptr)
result->_list = this->_list->clone(); result->_list = this->_list->clone();


@ -93,7 +93,7 @@ namespace nd4j {
} }
OpBenchmark* clone() override { OpBenchmark* clone() override {
return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : _x->dup() , _y == nullptr ? _y : _y->dup(), _z == nullptr ? _z : _z->dup()); return new ScalarBenchmark((scalar::Ops) _opNum, _testName, _x == nullptr ? _x : new NDArray(_x->dup()) , _y == nullptr ? _y : new NDArray(_y->dup()), _z == nullptr ? _z : new NDArray(_z->dup()));
} }
}; };
} }


@ -230,17 +230,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, con
bool cNcont = N == 1 || C->strideAt(1) == 1; bool cNcont = N == 1 || C->strideAt(1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
toDelete.push_back(pA); toDelete.push_back(pA);
aMcont = true; aMcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('f'); pB = new NDArray(B->dup('f'));
toDelete.push_back(pB); toDelete.push_back(pB);
bKcont = true; bKcont = true;
} }
if(!cMcont && !cNcont) { if(!cMcont && !cNcont) {
pC = C->dup('f'); pC = new NDArray(C->dup('f'));
toDelete.push_back(pC); toDelete.push_back(pC);
cMcont = true; cMcont = true;
} }
@ -332,7 +332,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
bool aNcont = N == 1 || A->strideAt(1) == 1; bool aNcont = N == 1 || A->strideAt(1) == 1;
if(!aMcont && !aNcont) { if(!aMcont && !aNcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
aMcont = true; aMcont = true;
} }
const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor; const CBLAS_ORDER blasOrder = aMcont ? CblasColMajor : CblasRowMajor;
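
The MmulHelper hunks check whether each operand has a unit stride along its M/K/N dimension and, if a matrix is contiguous in neither direction, duplicate it into column-major ('f') order before the BLAS call; with dup() now returning by value, the temporary is hoisted to the heap and tracked in toDelete. A self-contained sketch of the contiguity check and repack (toy Matrix type and hypothetical repack helper):

```cpp
// Toy Matrix type and hypothetical repack helper; strides are in elements.
#include <cstdio>
#include <vector>

struct Matrix {
    int rows, cols;
    int rowStride, colStride;
    std::vector<float> buf;
    float at(int i, int j) const { return buf[i * rowStride + j * colStride]; }
};

// Copy into dense column-major ('f') storage: rowStride 1, colStride rows.
Matrix repackColMajor(const Matrix& m) {
    Matrix f{m.rows, m.cols, 1, m.rows, std::vector<float>(m.rows * m.cols)};
    for (int j = 0; j < m.cols; ++j)
        for (int i = 0; i < m.rows; ++i)
            f.buf[i + j * m.rows] = m.at(i, j);
    return f;
}

int main() {
    // A 2x3 view with strides {6, 2}: contiguous in neither direction.
    Matrix a{2, 3, 6, 2,
             {1.f, 0.f, 2.f, 0.f, 3.f, 0.f, 4.f, 0.f, 5.f, 0.f, 6.f, 0.f}};
    const bool mContiguous = (a.rows == 1 || a.rowStride == 1);
    const bool kContiguous = (a.cols == 1 || a.colStride == 1);
    Matrix forBlas = (!mContiguous && !kContiguous) ? repackColMajor(a) : a;
    std::printf("%.0f %.0f\n", forBlas.at(0, 2), forBlas.at(1, 0));  // prints "3 4"
    return 0;
}
```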


@ -60,11 +60,10 @@ NDArray Householder<T>::evalHHmatrix(const NDArray& x) {
w.p(Nd4jLong(0), 1.f); w.p(Nd4jLong(0), 1.f);
wT.assign(&w); wT.assign(&w);
auto identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext()); NDArray identity = NDArrayFactory::create(x.ordering(), {(int)x.lengthOf(), (int)x.lengthOf()}, x.dataType(), x.getContext());
identity.setIdentity(); // identity matrix identity.setIdentity(); // identity matrix
return identity - mmul(w, wT) * coeff; return identity - mmul(w, wT) * coeff;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -95,9 +94,9 @@ void Householder<T>::evalHHmatrixData(const NDArray& x, NDArray& tail, T& coeff,
coeff = -u0 / normX; coeff = -u0 / normX;
if(x.isRowVector()) if(x.isRowVector())
tail.assign(x({0,0, 1,-1}) / u0); tail.assign(static_cast<const NDArray&>(x({0,0, 1,-1})) / u0);
else else
tail.assign(x({1,-1, 0,0,}) / u0); tail.assign(static_cast<const NDArray&>(x({1,-1, 0,0,})) / u0);
} }
} }
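
The static_cast<const NDArray&> wrappers above are about overload selection: with move-aware arithmetic operators in place, a temporary produced by a view expression such as x({0,0, 1,-1}) would otherwise bind to the rvalue overload, which may reuse the temporary's buffer, and view-like arrays must not have their resources taken (see the commit about not using resources of view-like arrays under move semantics). Casting to a const lvalue reference forces the copying overload. A standalone illustration with a toy type:

```cpp
// Toy type demonstrating lvalue vs rvalue overload selection; not NDArray itself.
#include <iostream>
#include <utility>
#include <vector>

struct Array {
    std::vector<float> data;
};

Array operator/(const Array& a, float s) {     // copying overload
    Array r{a.data};
    for (auto& v : r.data) v /= s;
    std::cout << "copying operator/\n";
    return r;
}

Array operator/(Array&& a, float s) {          // overload that reuses the temporary's buffer
    for (auto& v : a.data) v /= s;
    std::cout << "in-place operator/ on rvalue\n";
    return std::move(a);
}

Array makeView() { return Array{{2.f, 4.f}}; } // stands in for an expression like x({0,0, 1,-1})

int main() {
    auto a = makeView() / 2.f;                             // picks the rvalue overload
    auto b = static_cast<const Array&>(makeView()) / 2.f;  // cast forces the copying overload
    std::cout << a.data[0] << " " << b.data[1] << "\n";    // prints "1 2"
    return 0;
}
```

The same cast-to-const-reference idiom appears again in the JacobiSVD hunks further down.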


@ -269,7 +269,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
HHcolPivQR qr(matrix / scale); HHcolPivQR qr(matrix / scale);
_m.assign(qr._qr({0,_cols, 0,_cols})); _m.assign(qr._qr({0,_cols, 0,_cols}));
_m.fillAsTriangular<T>(0., 0, 0, 'l'); _m.fillAsTriangular<T>(0., 0, 0, _m, 'l');
HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); HHsequence hhSeg(qr._qr, qr._coeffs, 'u');
@ -288,7 +288,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
auto matrixT = matrix.transpose(); auto matrixT = matrix.transpose();
HHcolPivQR qr(matrixT / scale); HHcolPivQR qr(matrixT / scale);
_m.assign(qr._qr({0,_rows, 0,_rows})); _m.assign(qr._qr({0,_rows, 0,_rows}));
_m.fillAsTriangular<T>(0., 0, 0, 'l'); _m.fillAsTriangular<T>(0., 0, 0, _m, 'l');
_m.transposei(); _m.transposei();
HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here ! HHsequence hhSeg(qr._qr, qr._coeffs, 'u'); // type = 'u' is not mistake here !
@ -305,7 +305,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
} }
else { else {
_m.assign(matrix({0,_diagSize, 0,_diagSize}) / scale); _m.assign(static_cast<const NDArray&>(matrix({0,_diagSize, 0,_diagSize})) / scale);
if(_calcU) if(_calcU)
_u.setIdentity(); _u.setIdentity();
@ -366,7 +366,7 @@ void JacobiSVD<T>::evalData(const NDArray& matrix) {
_s.p(i, math::nd4j_abs<T>(_m.e<T>(i,i))); _s.p(i, math::nd4j_abs<T>(_m.e<T>(i,i)));
if(_calcU && _m.e<T>(i,i) < (T)0.) { if(_calcU && _m.e<T>(i,i) < (T)0.) {
auto temp = _u({0,0, i,i+1}, true); auto temp = _u({0,0, i,i+1}, true);
temp.applyTransform(transform::Neg, &temp, nullptr); temp.applyTransform(transform::Neg, temp, nullptr);
} }
} }


@ -223,26 +223,26 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
const T almostZero = DataTypeUtils::min<T>(); const T almostZero = DataTypeUtils::min<T>();
T maxElem; T maxElem;
if(len == 1) if(len == 1)
maxElem = math::nd4j_abs<T>(diagInterval->template e<T>(0)); maxElem = math::nd4j_abs<T>(diagInterval.template e<T>(0));
else else
maxElem = (*diagInterval)({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0); maxElem = diagInterval({1,-1, 0,0}, true).reduceNumber(reduce::AMax).template e<T>(0);
T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0); T maxElem0 = colVec0->reduceNumber(reduce::AMax).template e<T>(0);
T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem); T eps = math::nd4j_max<T>(almostZero, DataTypeUtils::eps<T>() * maxElem);
T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem); T epsBig = (T)8. * DataTypeUtils::eps<T>() * math::nd4j_max<T>(maxElem0, maxElem);
if(diagInterval->template e<T>(0) < epsBig) if(diagInterval.template e<T>(0) < epsBig)
diagInterval->p(Nd4jLong(0), epsBig); diagInterval.p(Nd4jLong(0), epsBig);
for(int i=1; i < len; ++i) for(int i=1; i < len; ++i)
if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps) if(math::nd4j_abs<T>(colVec0->template e<T>(i)) < eps)
colVec0->p(i, 0.f); colVec0->p(i, 0.f);
for(int i=1; i < len; i++) for(int i=1; i < len; i++)
if(diagInterval->template e<T>(i) < epsBig) { if(diagInterval.template e<T>(i) < epsBig) {
deflation1(col1, shift, i, len); deflation1(col1, shift, i, len);
for(int i = 0; i < len; ++i) for(int i = 0; i < len; ++i)
diagInterval->p(i, _m.e<T>(col1+shift+i,col1+shift+i)); diagInterval.p(i, _m.e<T>(col1+shift+i,col1+shift+i));
} }
{ {
@ -261,7 +261,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
int p = 1; int p = 1;
for(int i=1; i<len; ++i) for(int i=1; i<len; ++i)
if(math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero) if(math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero)
permut[p++] = i; permut[p++] = i;
int k = 1, m = ind+1; int k = 1, m = ind+1;
@ -271,7 +271,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
permut[p] = m++; permut[p] = m++;
else if(m >= len) else if(m >= len)
permut[p] = k++; permut[p] = k++;
else if(diagInterval->template e<T>(k) < diagInterval->template e<T>(m)) else if(diagInterval.template e<T>(k) < diagInterval.template e<T>(m))
permut[p] = m++; permut[p] = m++;
else else
permut[p] = k++; permut[p] = k++;
@ -281,7 +281,7 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
if(totDefl) { if(totDefl) {
for(int i=1; i<len; ++i) { for(int i=1; i<len; ++i) {
int ki = permut[i]; int ki = permut[i];
if(math::nd4j_abs<T>(diagInterval->template e<T>(ki)) < almostZero || diagInterval->template e<T>(0) < diagInterval->template e<T>(ki)) if(math::nd4j_abs<T>(diagInterval.template e<T>(ki)) < almostZero || diagInterval.template e<T>(0) < diagInterval.template e<T>(ki))
permut[i-1] = permut[i]; permut[i-1] = permut[i];
else { else {
permut[i-1] = 0; permut[i-1] = 0;
@ -303,10 +303,10 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
const int ki = permut[len - (totDefl ? i+1 : i)]; const int ki = permut[len - (totDefl ? i+1 : i)];
const int jac = tCol[ki]; const int jac = tCol[ki];
T _e0 = diagInterval->template e<T>(jac); T _e0 = diagInterval.template e<T>(jac);
//math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac)); //math::nd4j_swap<T>(diagInterval)(i), (*diagInterval)(jac));
diagInterval->p(jac, diagInterval->template e<T>(i)); diagInterval.p(jac, diagInterval.template e<T>(i));
diagInterval->p(i, _e0); diagInterval.p(i, _e0);
if(i!=0 && jac!=0) { if(i!=0 && jac!=0) {
_e0 = colVec0->template e<T>(jac); _e0 = colVec0->template e<T>(jac);
@ -315,9 +315,8 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
colVec0->p(i, _e0); colVec0->p(i, _e0);
} }
NDArray* temp1 = nullptr, *temp2 = nullptr;
if (_calcU) { if (_calcU) {
auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true); auto temp1 = _u({col1,col1+len+1, col1+i, col1+i+1}, true);
auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true); auto temp2 = _u({col1,col1+len+1, col1+jac,col1+jac+1}, true);
auto temp3 = temp1; auto temp3 = temp1;
temp1.assign(temp2); temp1.assign(temp2);
@ -352,12 +351,12 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
{ {
int i = len-1; int i = len-1;
while(i > 0 && (math::nd4j_abs<T>(diagInterval->template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero)) while(i > 0 && (math::nd4j_abs<T>(diagInterval.template e<T>(i)) < almostZero || math::nd4j_abs<T>(colVec0->template e<T>(i)) < almostZero))
--i; --i;
for(; i > 1; --i) { for(; i > 1; --i) {
if( (diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) { if( (diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) < DataTypeUtils::eps<T>()*maxElem ) {
if (math::nd4j_abs<T>(diagInterval->template e<T>(i) - diagInterval->template e<T>(i-1)) >= epsBig) if (math::nd4j_abs<T>(diagInterval.template e<T>(i) - diagInterval.template e<T>(i-1)) >= epsBig)
throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !"); throw std::runtime_error("ops::helpers::SVD::deflation: diagonal elements are not properly sorted !");
deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len); deflation2(col1, col1 + shift, row1W, col1W, i-1, i, len);
} }
@ -365,7 +364,6 @@ void SVD<T>::deflation(int col1, int col2, int ind, int row1W, int col1W, int sh
} }
delete colVec0; delete colVec0;
delete diagInterval;
} }
@ -609,9 +607,7 @@ void SVD<T>::calcBlockSVD(int col1, int size, NDArray& U, NDArray& singVals, NDA
const T almostZero = DataTypeUtils::min<T>(); const T almostZero = DataTypeUtils::min<T>();
auto col0 = _m({col1, col1+size, col1, col1+1}, true); auto col0 = _m({col1, col1+size, col1, col1+1}, true);
auto diagP = _m({col1, col1+size, col1, col1+size}, true).diagonal('c'); auto diag = static_cast<const NDArray&>(_m({col1, col1+size, col1, col1+size}, true).diagonal('c'));
auto diag = *diagP;
delete diagP;
diag.p(Nd4jLong(0), T(0)); diag.p(Nd4jLong(0), T(0));
singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext()); singVals = NDArrayFactory::create<T>(_m.ordering(), {size, 1}, _m.getContext());
@ -730,8 +726,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true); auto temp = _m({col1+shift,col1+shift+n+1, col1+shift,col1+shift+n}, true);
temp.assign(0.); temp.assign(0.);
auto diag = _m.diagonal('c'); auto diag = _m.diagonal('c');
(*diag)({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true)); diag({col1+shift, col1+shift+n, 0,0}, true).assign(jac._s({0,n, 0,0}, true));
delete diag;
return; return;
} }
@ -762,11 +757,6 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
f.assign(_u({0,1, col1+k+1, col1+n}, true)); f.assign(_u({0,1, col1+k+1, col1+n}, true));
} }
// UofSVD.printIndexedBuffer();
// VofSVD.printIndexedBuffer();
// singVals.printIndexedBuffer();
// printf("!! \n");
if (_calcV) if (_calcV)
_v.p(row1W+k, col1W, 1.f); _v.p(row1W+k, col1W, 1.f);
@ -789,14 +779,10 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
temp.assign(_u({col1, col1+k+1, i, i+1}, true)); temp.assign(_u({col1, col1+k+1, i, i+1}, true));
} }
auto temp1 = _u({col1,col1+k+1, col1,col1+1}, true); _u({col1,col1+k+1, col1,col1+1}, true).assign(q1 * c0);
temp1.assign(q1 * c0); _u({col1,col1+k+1, col2+1,col2+2}, true).assign(q1 * (-s0));
auto temp2 = _u({col1,col1+k+1, col2+1,col2+2}, true); _u({col1+k+1,col1+n+1, col1, col1+1}, true).assign(static_cast<const NDArray&>(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true)) * s0);
temp2.assign(q1 * (-s0)); _u({col1+k+1,col1+n+1, col2+1,col2+2}, true) *= c0;
auto temp3 = _u({col1+k+1,col1+n+1, col1, col1+1}, true);
temp3.assign(_u({col1+k+1, col1+n+1, col2+1, col2+2}, true) * s0);
auto temp4 =_u({col1+k+1,col1+n+1, col2+1,col2+2}, true);
temp4 *= c0;
} }
else { else {
@ -844,8 +830,7 @@ void SVD<T>::DivideAndConquer(int col1, int col2, int row1W, int col1W, int shif
auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true); auto blockM = _m({col1+shift,col1+shift+n, col1+shift,col1+shift+n}, true);
blockM = 0.f; blockM = 0.f;
auto diag = blockM.diagonal('c'); auto diag = blockM.diagonal('c');
diag->assign(singVals); diag.assign(singVals);
delete diag;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
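
Throughout the SVD hunks, helpers such as diagonal() and the sub-array reductions now hand back lightweight view or value objects instead of heap-allocated NDArray*, which is why the arrow dereferences and the paired delete lines are gone. A toy sketch of a by-value diagonal view that writes through to the parent buffer (hypothetical types, not the real NDArray view machinery):

```cpp
// Toy by-value diagonal view; not the real NDArray view machinery.
#include <cstddef>
#include <iostream>
#include <vector>

struct DiagView {
    float* base;
    std::size_t step, len;
    void assign(float v) { for (std::size_t i = 0; i < len; ++i) base[i * step] = v; }
};

struct Matrix {
    std::size_t n;
    std::vector<float> buf;                                   // row-major n x n
    DiagView diagonal() { return DiagView{buf.data(), n + 1, n}; }
};

int main() {
    Matrix m{3, std::vector<float>(9, 0.f)};
    m.diagonal().assign(1.f);                                 // writes through the view; nothing to delete
    std::cout << m.buf[0] << " " << m.buf[4] << " " << m.buf[8] << "\n";  // prints "1 1 1"
    return 0;
}
```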


@ -285,17 +285,17 @@ NDArray* MmulHelper::mmulMxM(const NDArray* A, const NDArray* B, NDArray* C, dou
bool cNcont = N == 1 || C->strideAt(1) == 1; bool cNcont = N == 1 || C->strideAt(1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
toDelete.push_back(pA); toDelete.push_back(pA);
aMcont = true; aMcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('f'); pB = new NDArray(B->dup('f'));
toDelete.push_back(pB); toDelete.push_back(pB);
bKcont = true; bKcont = true;
} }
if(!cMcont) { if(!cMcont) {
pC = C->dup('f'); pC = new NDArray(C->dup('f'));
toDelete.push_back(pC); toDelete.push_back(pC);
cMcont = true; cMcont = true;
} }
@ -418,7 +418,7 @@ NDArray* MmulHelper::mmulMxV(const NDArray* A, const NDArray* X, nd4j::NDArray*
bool aNcont = N == 1 || A->strideAt(1) == 1; bool aNcont = N == 1 || A->strideAt(1) == 1;
if(!aMcont && !aNcont) { if(!aMcont && !aNcont) {
pA = A->dup('f'); pA = new NDArray(A->dup('f'));
aMcont = true; aMcont = true;
} }
@ -866,12 +866,12 @@ NDArray* MmulHelper::mmulNxNold2(const NDArray* A, const NDArray* B, NDArray* C,
bool cNcont = N == 1 || C->strideAt(-1) == 1; bool cNcont = N == 1 || C->strideAt(-1) == 1;
if(!aMcont && !aKcont) { if(!aMcont && !aKcont) {
pA = A->dup('c'); pA = new NDArray(A->dup('c'));
toDelete.push_back(pA); toDelete.push_back(pA);
aKcont = true; aKcont = true;
} }
if(!bKcont && !bNcont) { if(!bKcont && !bNcont) {
pB = B->dup('c'); pB = new NDArray(B->dup('c'));
toDelete.push_back(pB); toDelete.push_back(pB);
bNcont = true; bNcont = true;
} }


@ -82,7 +82,7 @@ namespace nd4j {
// now we actually apply quantization // now we actually apply quantization
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e += increment) {
rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>(1.0f * x[e] / nd4j::math::nd4j_max<float>(amax, amin) * max_byte)); rz[e] = static_cast<char>(nd4j::math::nd4j_round<float, char>( 1.0f * static_cast<float>(x[e]) / nd4j::math::nd4j_max<float>(amax, amin) * max_byte));
} }
}; };
@ -180,7 +180,7 @@ PRAGMA_OMP_ATOMIC_ARGS(write)
for (auto e = start; e < stop; e += increment) { for (auto e = start; e < stop; e += increment) {
int el = x[e]; int el = x[e];
int ael = nd4j::math::nd4j_abs<int>(el) - 1; int ael = nd4j::math::nd4j_abs<int>(el) - 1;
z[ael] += el > 0 ? threshold : -threshold; z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
} }
}; };
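
The two casts added above make the float16/bfloat16 paths explicit: x[e] is widened to float before the division, and the float threshold is converted to the element type T before the compound assignment. A hedged sketch of the second pattern with a toy Half type whose float constructor is explicit (the real float16/bfloat16 classes differ):

```cpp
// Toy Half type with an explicit float constructor; the real float16/bfloat16 differ.
#include <iostream>

struct Half {
    float v = 0.f;
    Half() = default;
    explicit Half(float f) : v(f) {}
    Half& operator+=(const Half& o) { v += o.v; return *this; }
};

template <typename T>
void accumulateThreshold(T* z, const int* x, int n, float threshold) {
    for (int e = 0; e < n; ++e) {
        const int el = x[e];
        if (el == 0) continue;
        const int ael = (el > 0 ? el : -el) - 1;
        // Explicit casts keep this valid even when T's conversion from float is explicit.
        z[ael] += el > 0 ? static_cast<T>(threshold) : static_cast<T>(-threshold);
    }
}

int main() {
    Half z[2];
    const int x[3] = {1, -2, 1};
    accumulateThreshold(z, x, 3, 0.5f);
    std::cout << z[0].v << " " << z[1].v << "\n";   // prints "1 -0.5"
    return 0;
}
```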


@ -32,21 +32,19 @@ namespace nd4j {
REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type"); REQUIRE_TRUE(x->isR(), 0, "CRELU: input must be real type");
auto tmp = x->dup(); auto tmp = x->dup();
tmp->applyTransform(nd4j::transform::Neg, nullptr, nullptr); tmp.applyTransform(nd4j::transform::Neg, tmp);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
helpers::concat(block.launchContext(), {x, tmp}, *z, x->rankOf()-1); helpers::concat(block.launchContext(), {x, &tmp}, *z, x->rankOf()-1);
// NDArrayFactory<T>::concat({x, tmp}, -1, z); // NDArrayFactory<T>::concat({x, tmp}, -1, z);
// TODO: make this configurable? // TODO: make this configurable?
double threshold = 0.0; double threshold = 0.0;
z->applyScalar(nd4j::scalar::RELU, threshold); z->applyScalar(nd4j::scalar::RELU, threshold, *z);
STORE_RESULT(z); STORE_RESULT(z);
delete tmp;
return Status::OK(); return Status::OK();
} }
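
With dup() returning by value, the CRELU forward pass above negates the copy in place, concatenates it after the input and applies RELU through the reference-based applyScalar, with no trailing delete. The flattened 1-D sketch below shows the computed result with plain vectors (the real op concatenates along the last axis of an N-D array):

```cpp
// Plain-vector sketch of the CRELU result; not the libnd4j op itself.
#include <algorithm>
#include <iostream>
#include <vector>

std::vector<float> crelu(const std::vector<float>& x) {
    std::vector<float> tmp(x);                                   // dup()
    std::transform(tmp.begin(), tmp.end(), tmp.begin(),
                   [](float v) { return -v; });                  // Neg, applied in place
    std::vector<float> z(x);
    z.insert(z.end(), tmp.begin(), tmp.end());                   // concat(x, -x)
    std::transform(z.begin(), z.end(), z.begin(),
                   [](float v) { return v > 0.f ? v : 0.f; });   // RELU with threshold 0
    return z;
}

int main() {
    for (float v : crelu({1.f, -2.f}))
        std::cout << v << " ";                                   // prints "1 0 0 2"
    std::cout << "\n";
    return 0;
}
```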
@ -61,7 +59,7 @@ namespace nd4j {
std::vector<Nd4jLong> shape; std::vector<Nd4jLong> shape;
for (int e = 0; e < shape::rank(inShape); e++) for (int e = 0; e < shape::rank(inShape); e++)
shape.emplace_back(shape::shapeOf(inShape)[e]); shape.emplace_back(shape::shapeOf(inShape)[e]);
shape[shape.size()-1] *= 2; shape[shape.size()-1] *= 2;
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape); auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShape), shape::order(inShape), shape);
@ -94,7 +92,7 @@ namespace nd4j {
auto pos = dec->at(0); auto pos = dec->at(0);
auto neg = dec->at(1); auto neg = dec->at(1);
pos->applyPairwiseTransform(nd4j::pairwise::Subtract, neg, epsilon, nullptr); pos->applyPairwiseTransform(nd4j::pairwise::Subtract, *neg, *epsilon);
delete tmpResult; delete tmpResult;
delete dec; delete dec;


@ -31,9 +31,9 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::Cube, output, nullptr); input->applyTransform(nd4j::transform::Cube, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
} }


@ -32,7 +32,7 @@ namespace nd4j {
const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
input->applyScalar(nd4j::scalar::ELU, alpha, output); input->applyScalar(nd4j::scalar::ELU, alpha, *output);
return Status::OK(); return Status::OK();
} }


@ -30,9 +30,9 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::HardSigmoid, output, nullptr); input->applyTransform(nd4j::transform::HardSigmoid, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
} }


@ -30,9 +30,9 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::HardTanh, output, nullptr); input->applyTransform(nd4j::transform::HardTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
} }


@ -30,7 +30,7 @@ namespace nd4j {
auto z = this->getZ(block); auto z = this->getZ(block);
// just for lulz // just for lulz
first->applyTransform(nd4j::transform::Identity, z, nullptr); first->applyTransform(nd4j::transform::Identity, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -33,7 +33,7 @@ namespace nd4j {
auto x = INPUT_VARIABLE(i); auto x = INPUT_VARIABLE(i);
auto z = OUTPUT_VARIABLE(i); auto z = OUTPUT_VARIABLE(i);
x->applyTransform(transform::Identity, z, nullptr); x->applyTransform(transform::Identity, *z);
} }
} }


@ -31,7 +31,7 @@ namespace nd4j {
float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output); input->applyScalar(nd4j::scalar::LeakyRELU, alpha, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();


@ -30,9 +30,9 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::RationalTanh, output, nullptr); input->applyTransform(nd4j::transform::RationalTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
} }


@ -30,9 +30,9 @@ namespace nd4j {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyTransform(nd4j::transform::RectifiedTanh, output, nullptr); input->applyTransform(nd4j::transform::RectifiedTanh, *output);
STORE_RESULT(output); STORE_RESULT(output);
return Status::OK(); return Status::OK();
} }


@ -32,7 +32,7 @@ namespace nd4j {
auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0; auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
first->applyScalar(nd4j::scalar::RELU, scalar, z); first->applyScalar(nd4j::scalar::RELU, scalar, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -33,8 +33,8 @@ CONFIGURABLE_OP_IMPL(relu6, 1, 1, true, 1, 0) {
auto input = INPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0);
input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), output); input->applyScalar(nd4j::scalar::RELU6, T_ARG(0), *output);
return Status::OK(); return Status::OK();
} }


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SELU, z, nullptr); first->applyTransform(nd4j::transform::SELU, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -29,7 +29,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::Sigmoid, z, nullptr); first->applyTransform(nd4j::transform::Sigmoid, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SoftPlus, z, nullptr); first->applyTransform(nd4j::transform::SoftPlus, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::SoftSign, z, nullptr); first->applyTransform(nd4j::transform::SoftSign, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -30,7 +30,7 @@ namespace nd4j {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
first->applyTransform(nd4j::transform::Tanh, z, nullptr); first->applyTransform(nd4j::transform::Tanh, *z);
STORE_RESULT(*z); STORE_RESULT(*z);


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), *y, *z, false);
return Status::OK(); return Status::OK();
} }


@ -37,14 +37,14 @@ namespace nd4j {
if (block.width() > 2) { if (block.width() > 2) {
auto alpha = INPUT_VARIABLE(2); auto alpha = INPUT_VARIABLE(2);
REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg"); REQUIRE_TRUE(alpha->isScalar(), 0, "Axpy: alpha argument should be scalar or TArg");
} else if (block.getTArguments()->size() > 0) { } else if (block.getTArguments()->size() > 0) {
a = T_ARG(0); a = T_ARG(0);
} }
ExtraArguments arguments({a}); ExtraArguments arguments({a});
y->applyPairwiseTransform(pairwise::Axpy, x, z, &arguments); y->applyPairwiseTransform(pairwise::Axpy, *x, *z, &arguments);
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }


@ -33,8 +33,12 @@ CUSTOM_OP_IMPL(svd, 1, 1, false, 0, 3) {
const int rank = x->rankOf(); const int rank = x->rankOf();
REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank); REQUIRE_TRUE(rank >= 2 , 0, "SVD OP: the rank of input array must be >=2, but got %i instead!", rank);
const bool fullUV = (bool)INT_ARG(0); bool fullUV = (bool)INT_ARG(0);
const bool calcUV = (bool)INT_ARG(1); const bool calcUV = (bool)INT_ARG(1);
if(calcUV == false)
fullUV = false;
const int switchNum = INT_ARG(2); const int switchNum = INT_ARG(2);
// #ifndef __CUDABLAS__ // #ifndef __CUDABLAS__


@ -29,7 +29,7 @@ namespace nd4j {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
x->applyTransform(transform::Not, z, nullptr); x->applyTransform(transform::Not, *z);
return Status::OK(); return Status::OK();
} }


@ -70,17 +70,13 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!cond->e<bool>(e)) { if (!cond->e<bool>(e)) {
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
} else { } else {
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} }


@ -59,17 +59,13 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!condition->e<bool>(e)) { if (!condition->e<bool>(e)) {
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
} else { } else {
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} else { } else {
// in this case we return 2D matrix, which basically contains coordinates fo true // in this case we return 2D matrix, which basically contains coordinates fo true


@ -89,16 +89,12 @@ namespace nd4j {
auto tadsY = y->allTensorsAlongDimension(dims); auto tadsY = y->allTensorsAlongDimension(dims);
auto tadsZ = z->allTensorsAlongDimension(dims); auto tadsZ = z->allTensorsAlongDimension(dims);
for (int e = 0; e < tadsX->size(); e++) { for (int e = 0; e < tadsX.size(); e++) {
if (!condition->e<bool>(e)) if (!condition->e<bool>(e))
tadsZ->at(e)->assign(tadsY->at(e)); tadsZ.at(e)->assign(tadsY.at(e));
else else
tadsZ->at(e)->assign(tadsX->at(e)); tadsZ.at(e)->assign(tadsX.at(e));
} }
delete tadsX;
delete tadsY;
delete tadsZ;
} }
} else { } else {
// in this case we return 2D matrix, which basically contains coordinates fo true // in this case we return 2D matrix, which basically contains coordinates fo true


@ -30,7 +30,7 @@ namespace nd4j {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1); auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
auto tZ = BroadcastHelper::broadcastApply(nd4j::BroadcastOpsTuple::Add(), x, y, z); auto tZ = BroadcastHelper::broadcastApply(nd4j::BroadcastOpsTuple::Add(), x, y, z);
@ -82,14 +82,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum; } else
} else
gradX->assign(epsNext); gradX->assign(epsNext);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(epsNext); gradY->assign(epsNext);
} }
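
This backward pass follows the usual broadcast-gradient recipe: whichever operand was broadcast gets epsNext summed over the broadcast axes, and since reduceAlongDimension now returns its result by value, the old delete sum lines vanish. A hedged sketch of how those axes are derived (plain containers, hypothetical helper name, not ShapeUtils::evalBroadcastBackwardAxis itself):

```cpp
// Hypothetical helper, plain containers; the real code uses ShapeUtils and NDArray reductions.
#include <iostream>
#include <vector>

// Axes of the gradient shape along which `operand` was broadcast (size-1 or missing dims).
std::vector<int> broadcastBackwardAxes(const std::vector<int>& operandShape,
                                       const std::vector<int>& gradShape) {
    std::vector<int> axes;
    const int offset = static_cast<int>(gradShape.size() - operandShape.size());
    for (int a = 0; a < static_cast<int>(gradShape.size()); ++a) {
        const int dim = a < offset ? 1 : operandShape[a - offset];
        if (dim == 1 && gradShape[a] != 1)
            axes.push_back(a);
    }
    return axes;
}

int main() {
    // operand shape [3] broadcast against gradient shape [2, 3]: sum over axis 0.
    const auto axes = broadcastBackwardAxes({3}, {2, 3});
    std::cout << "reduce gradient over axis " << axes[0] << "\n";  // prints "reduce gradient over axis 0"
    return 0;
}
```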


@ -39,7 +39,7 @@ namespace nd4j {
else if (tZ != z) { else if (tZ != z) {
OVERWRITE_RESULT(tZ); OVERWRITE_RESULT(tZ);
} }
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
DECLARE_SYN(set, assign); DECLARE_SYN(set, assign);
@ -80,7 +80,6 @@ namespace nd4j {
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(nd4j::reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(epsNext); gradY->assign(epsNext);
} }
@ -98,7 +97,7 @@ namespace nd4j {
Nd4jLong *shapeE; Nd4jLong *shapeE;
Nd4jLong *shapeG; Nd4jLong *shapeG;
COPY_SHAPE(x, shapeE); COPY_SHAPE(x, shapeE);
COPY_SHAPE(y, shapeG); COPY_SHAPE(y, shapeG);


@ -28,7 +28,7 @@ namespace nd4j {
namespace ops { namespace ops {
BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) { BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) {
auto y = INPUT_VARIABLE(0); auto y = INPUT_VARIABLE(0);
auto x = INPUT_VARIABLE(1); auto x = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
@ -36,8 +36,8 @@ BROADCASTABLE_OP_IMPL(tf_atan2, 0, 0) {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
// auto tZ = BroadcastHelper<T>::template broadcastApply<simdOps::Atan2<T>>(y, x, z); // auto tZ = BroadcastHelper<T>::template broadcastApply<simdOps::Atan2<T>>(y, x, z);
x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), y, z, true); x->applyTrueBroadcast(nd4j::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), *y, *z, true);
// if (tZ == nullptr) // if (tZ == nullptr)
// return ND4J_STATUS_KERNEL_FAILURE; // return ND4J_STATUS_KERNEL_FAILURE;
// else if (tZ != z) { // else if (tZ != z) {


@ -81,7 +81,7 @@ namespace nd4j {
// Y gradient // Y gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
gradY->assign((*epsNext) * (*x) / ((*y) * (*y))); gradY->assign((*epsNext) * (*x) / ((*y) * (*y)));
gradY->applyTransform(transform::Neg, nullptr, nullptr); gradY->applyTransform(transform::Neg, *gradY);
} else if (y->isScalar()) { } else if (y->isScalar()) {
// scalar case // scalar case
@ -91,17 +91,17 @@ namespace nd4j {
//tmpX.printBuffer("SumX"); //tmpX.printBuffer("SumX");
//tmp.printBuffer("Sum Eps"); //tmp.printBuffer("Sum Eps");
gradY->assign(tmp * tmpX / ((*y) * (*y))); gradY->assign(tmp * tmpX / ((*y) * (*y)));
gradY->applyTransform(transform::Neg, nullptr, nullptr); gradY->applyTransform(transform::Neg, *gradY);
//epsNext->applyLambda(lambdaS, gradX); //epsNext->applyLambda(lambdaS, *gradX);
epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); epsNext->applyScalarArr(scalar::Divide, *y, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preX = *epsNext / *y; auto preX = *epsNext / *y;
NDArray negX(*x); NDArray negX(*x);
x->applyTransform(transform::Neg, &negX); x->applyTransform(transform::Neg, negX);
auto preY = *epsNext * negX / ((*y) * (*y)); auto preY = *epsNext * negX / ((*y) * (*y));
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -110,14 +110,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum; } else
} else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -69,7 +69,7 @@ namespace nd4j {
std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {})); std::unique_ptr<ResultSet> tmpResult(op.execute({x, y}, {}, {}, {}));
if (gradY->rankOf() == gradX->rankOf()) if (gradY->rankOf() == gradX->rankOf())
epsNext->applyPairwiseTransform(pairwise::Multiply, tmpResult->at(0), gradY, nullptr); epsNext->applyPairwiseTransform(pairwise::Multiply, *tmpResult->at(0), *gradY);
else // epsNext is greater than gradY else // epsNext is greater than gradY
{ {
std::vector<Nd4jLong> dims(epsNext->rankOf() * 2); std::vector<Nd4jLong> dims(epsNext->rankOf() * 2);
@ -78,7 +78,7 @@ namespace nd4j {
dims[d * 2 + 1] = 1; dims[d * 2 + 1] = 1;
} }
auto tempIn((*tmpResult->at(0))(dims)); auto tempIn((*tmpResult->at(0))(dims));
(*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, &tempIn, gradY, nullptr); (*epsNext)(dims).applyPairwiseTransform(pairwise::Multiply, tempIn, *gradY);
} }
return Status::OK(); return Status::OK();
} }


@ -79,42 +79,42 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
const Nd4jLong yLen = y->lengthOf(); const Nd4jLong yLen = y->lengthOf();
if(x->isScalar() && y->isScalar()) { // both are scalars if(x->isScalar() && y->isScalar()) { // both are scalars
y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
//dLdx->assign((*y) * (*dLdz)); //dLdx->assign((*y) * (*dLdz));
//dLdy->assign((*x) * (*dLdz)); //dLdy->assign((*x) * (*dLdz));
} }
else if(x->isScalar()) { // x is scalar and y is not else if(x->isScalar()) { // x is scalar and y is not
dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum)); dLdx->assign((*y * *dLdz).reduceNumber(reduce::Sum));
dLdz->applyScalarArr(scalar::Multiply, x, dLdy, nullptr); dLdz->applyScalarArr(scalar::Multiply, *x, *dLdy);
//dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true); //dLdz->applyTrueBroadcast(broadcast::Multiply, x, dLdy, true);
} }
else if(y->isScalar()) { // y is scalar and x is not else if(y->isScalar()) { // y is scalar and x is not
dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum)); dLdy->assign((*x * *dLdz).reduceNumber(reduce::Sum));
dLdz->applyScalarArr(scalar::Multiply, y, dLdx); dLdz->applyScalarArr(scalar::Multiply, *y, *dLdx);
} }
else if(x->isSameShape(y)) { else if(x->isSameShape(y)) {
x->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); x->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
y->applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); y->applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
} }
else if (x->isSameShape(dLdz)) { else if (x->isSameShape(dLdz)) {
auto yTiled = NDArray(dLdz, false, block.launchContext()); auto yTiled = NDArray(dLdz, false, block.launchContext());
y->tile(yTiled); y->tile(yTiled);
std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo());
dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) );
yTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdx, nullptr); yTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdx);
} }
else if (y->isSameShape(dLdz)) { else if (y->isSameShape(dLdz)) {
auto xTiled = NDArray(dLdz, false, block.launchContext()); auto xTiled = NDArray(dLdz, false, block.launchContext());
x->tile(xTiled); x->tile(xTiled);
std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo());
dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) );
xTiled.applyPairwiseTransform(pairwise::Multiply, dLdz, dLdy, nullptr); xTiled.applyPairwiseTransform(pairwise::Multiply, *dLdz, *dLdy);
} }
else { else {
@ -124,16 +124,16 @@ CUSTOM_OP_IMPL(multiply_bp, 3, 2, false, 0, 0) {
y->tile(yTiled); y->tile(yTiled);
std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForX = ShapeUtils::evalBroadcastBackwardAxis(x->getShapeInfo(), dLdz->getShapeInfo());
std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo()); std::vector<int> axesForY = ShapeUtils::evalBroadcastBackwardAxis(y->getShapeInfo(), dLdz->getShapeInfo());
dLdx->assign( (*y * *dLdz).reduceAlongDims(reduce::Sum, axesForX) ); dLdx->assign( (*y * *dLdz).reduceAlongDimension(reduce::Sum, axesForX) );
dLdy->assign( (*x * *dLdz).reduceAlongDims(reduce::Sum, axesForY) ); dLdy->assign( (*x * *dLdz).reduceAlongDimension(reduce::Sum, axesForY) );
} }
return Status::OK(); return Status::OK();
} }
DECLARE_SHAPE_FN(multiply_bp) { DECLARE_SHAPE_FN(multiply_bp) {
auto xShapeInfo = inputShape->at(0); auto xShapeInfo = inputShape->at(0);
auto yShapeInfo = inputShape->at(1); auto yShapeInfo = inputShape->at(1);
@ -181,8 +181,8 @@ DECLARE_SHAPE_FN(multiply_bp) {
T tmpX = x->template reduceNumber<simdOps::Sum<T>>(); T tmpX = x->template reduceNumber<simdOps::Sum<T>>();
gradY->assign(tmpX); gradY->assign(tmpX);
epsNext->applyLambda(lambdaS, gradX); epsNext->applyLambda(lambdaS, *gradX);
} else { } else {
// broadcast case // broadcast case
@ -201,7 +201,7 @@ DECLARE_SHAPE_FN(multiply_bp) {
auto sum = preX->template reduceAlongDimension<simdOps::Sum<T>>(axisX); auto sum = preX->template reduceAlongDimension<simdOps::Sum<T>>(axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum; delete sum;
} else } else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {


@ -71,7 +71,7 @@ namespace nd4j {
// X gradient // X gradient
//epsNext->applyPairwiseLambda(y, lambdaX, gradX); //epsNext->applyPairwiseLambda(y, lambdaX, gradX);
epsNext->applyPairwiseTransform(pairwise::Divide, y, gradX, nullptr); epsNext->applyPairwiseTransform(pairwise::Divide, *y, *gradX);
// Y gradient // Y gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY); //epsNext->applyTriplewiseLambda(x, y, lambdaY, gradY);
@ -84,16 +84,16 @@ namespace nd4j {
auto tmp = epsNext->reduceNumber(reduce::Sum); auto tmp = epsNext->reduceNumber(reduce::Sum);
auto tmpX = x->reduceNumber(reduce::Sum); auto tmpX = x->reduceNumber(reduce::Sum);
gradY->assign(tmp * -tmpX / ((*y) * (*y))); gradY->assign(tmp * -tmpX / ((*y) * (*y)));
//epsNext->applyLambda(lambdaS, gradX); //epsNext->applyLambda(lambdaS, gradX);
epsNext->applyScalarArr(scalar::Divide, y, gradX, nullptr); epsNext->applyScalarArr(scalar::Divide, *y, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preX = *epsNext / *y; auto preX = *epsNext / *y;
NDArray negX(*x); NDArray negX(*x);
x->applyTransform(transform::Neg, &negX); x->applyTransform(transform::Neg, negX);
auto preY = *epsNext * negX / ((*y) * (*y)); auto preY = *epsNext * negX / ((*y) * (*y));
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -102,14 +102,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum; } else
} else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -34,7 +34,7 @@ namespace nd4j {
BROADCAST_CHECK_EMPTY(x,y,z); BROADCAST_CHECK_EMPTY(x,y,z);
REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!"); REQUIRE_TRUE(!x->isB(), 0, "REVERSEDIVIDE OP: you can't divide by bool array!");
x->applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); x->applyTrueBroadcast(BROADCAST(ReverseDivide), *y, *z, true);
return Status::OK(); return Status::OK();
} }
@ -67,7 +67,7 @@ namespace nd4j {
// X gradient // X gradient
//epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX); //epsNext->applyTriplewiseLambda(x, y, lambdaX, gradX);
gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); gradX->assign((*epsNext) * (*y) / ((*x) * (*x)));
gradX->applyTransform(transform::Neg, nullptr, nullptr); gradX->applyTransform(transform::Neg, *gradX);
// Y gradient // Y gradient
//epsNext->applyPairwiseLambda(x, lambdaY, gradY); //epsNext->applyPairwiseLambda(x, lambdaY, gradY);
gradY->assign((*epsNext) / (*x)); gradY->assign((*epsNext) / (*x));
@ -78,14 +78,14 @@ namespace nd4j {
gradY->assign(tmp / tmpX); gradY->assign(tmp / tmpX);
gradX->assign((*epsNext) * (*y) / ((*x) * (*x))); gradX->assign((*epsNext) * (*y) / ((*x) * (*x)));
gradX->applyTransform(transform::Neg, nullptr, nullptr); gradX->applyTransform(transform::Neg, *gradX);
} else { } else {
// broadcast case // broadcast case
auto preY = (*epsNext) / (*x); auto preY = (*epsNext) / (*x);
auto preX = *epsNext * (*y) / ((*x) * (*x)); auto preX = *epsNext * (*y) / ((*x) * (*x));
preX.applyTransform(transform::Neg, nullptr, nullptr); preX.applyTransform(transform::Neg, preX);
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo()); auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
@ -93,14 +93,12 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = preX.reduceAlongDimension(reduce::Sum, axisX); auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
gradX->assign(sum); gradX->assign(sum);
delete sum; } else
} else
gradX->assign(preX); gradX->assign(preX);
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = preY.reduceAlongDimension(reduce::Sum, axisY); auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else } else
gradY->assign(preY); gradY->assign(preY);
} }


@ -61,13 +61,13 @@ namespace nd4j {
if (x->isSameShape(y)) { if (x->isSameShape(y)) {
// PWT case case // PWT case case
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
gradY->assign(epsNext); gradY->assign(epsNext);
} else if (y->isScalar()) { } else if (y->isScalar()) {
// scalar case // scalar case
auto tmp = epsNext->reduceNumber(reduce::Sum); auto tmp = epsNext->reduceNumber(reduce::Sum);
gradY->assign(tmp); gradY->assign(tmp);
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
} else { } else {
// broadcastable // broadcastable
auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo()); auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
@ -75,20 +75,18 @@ namespace nd4j {
if (axisX.size() > 0) { if (axisX.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX);
sum->applyTransform(transform::Neg, gradX); sum.applyTransform(transform::Neg, *gradX);
delete sum;
} else { } else {
epsNext->applyTransform(transform::Neg, gradX, nullptr); epsNext->applyTransform(transform::Neg, *gradX);
} }
if (axisY.size() > 0) { if (axisY.size() > 0) {
auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY); auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY);
gradY->assign(sum); gradY->assign(sum);
delete sum;
} else { } else {
gradY->assign(epsNext); gradY->assign(epsNext);
} }
} }
return Status::OK(); return Status::OK();
} }


@@ -87,7 +87,7 @@ namespace nd4j {
 // scalar case
 auto tmpX = x->reduceNumber(reduce::Sum);
 gradY->assign(tmpX);
 //epsNext->applyPairwiseLambda(x, lambdaS, gradX);
 gradX->assign((*epsNext) * ts * ((*x) - (*y)));
 } else {
@@ -98,37 +98,31 @@ namespace nd4j {
 auto targetShape = epsNext->getShapeAsVector();
-preX->tileToShape(targetShape);
-preY->tileToShape(targetShape);
+preX.tileToShape(targetShape, preX);
+preY.tileToShape(targetShape, preY);
 //epsNext->applyTriplewiseLambda(x, y, lambdaX, preX);
 //epsNext->applyTriplewiseLambda(x, y, lambdaY, preY);
 auto resX = (*epsNext) * ts * ((*x) - (*y));
-preX->assign(resX);
+preX.assign(resX);
 auto resY = (*epsNext) * ts * ((*y) - (*x));
-preY->assign(resY);
+preY.assign(resY);
 auto axisX = ShapeUtils::evalBroadcastBackwardAxis(x->shapeInfo(), epsNext->shapeInfo());
 auto axisY = ShapeUtils::evalBroadcastBackwardAxis(y->shapeInfo(), epsNext->shapeInfo());
 if (axisX.size() > 0) {
-auto sum = preX->reduceAlongDimension(reduce::Sum, axisX);
+auto sum = preX.reduceAlongDimension(reduce::Sum, axisX);
 gradX->assign(sum);
-delete sum;
 } else
 gradX->assign(preX);
 if (axisY.size() > 0) {
-auto sum = preY->reduceAlongDimension(reduce::Sum, axisY);
+auto sum = preY.reduceAlongDimension(reduce::Sum, axisY);
 gradY->assign(sum);
-delete sum;
 } else
 gradY->assign(preY);
-delete preX;
-delete preY;
 }
 return Status::OK();


@@ -62,7 +62,7 @@ namespace nd4j {
 if (x->isSameShape(y)) {
 // PWT case case
-epsNext->applyTransform(transform::Neg, gradY, nullptr);
+epsNext->applyTransform(transform::Neg, *gradY);
 gradX->assign(epsNext);
 } else if (y->isScalar()) {
 // scalar case
@@ -77,18 +77,16 @@ namespace nd4j {
 if (axisX.size() > 0) {
 auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisX);
 gradX->assign(sum);
-delete sum;
 } else
 gradX->assign(epsNext);
 if (axisY.size() > 0) {
 auto sum = epsNext->reduceAlongDimension(reduce::Sum, axisY);
-sum->applyTransform(transform::Neg, gradY);
-delete sum;
+sum.applyTransform(transform::Neg, *gradY);
 } else {
-epsNext->applyTransform(transform::Neg, gradY);
+epsNext->applyTransform(transform::Neg, *gradY);
 }
 }
 return Status::OK();
 }


@@ -26,7 +26,7 @@ namespace nd4j {
 namespace ops {
 /**
 * This operation is, basically IF statement
 *
 * arg_0 is our "signal"
 * arg_1 is condition that will determine transition
 */
@@ -41,10 +41,10 @@ namespace nd4j {
 // but we'll ensure only one node is active, and other is disabled
 if (condition->e<int>(0) == 0) {
 block.setBranch(0);
-this->storeResult(block, 0, input->dup());
+this->storeResult(block, 0, new NDArray(input->dup()));
 } else {
 block.setBranch(1);
-this->storeResult(block, 1, *input->dup());
+this->storeResult(block, 1, new NDArray(input->dup()));
 }
 return Status::OK();

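Side note (not part of the diff): NDArray::dup() now returns an NDArray by value rather than a raw pointer, per the signature-correction commits in this PR. Call sites that still need to hand over a heap-allocated array, such as storeResult above or NDArrayList::write further down, therefore wrap the duplicate in new NDArray(...). A hedged sketch of the pattern; the helper name is invented:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    using namespace nd4j;

    // Sketch only: dup() now returns by value; wrap it when an owning pointer is expected.
    static NDArray* duplicateForConsumer(NDArray* input) {
        // the consumer (e.g. storeResult or NDArrayList::write) takes ownership of the heap copy
        return new NDArray(input->dup());
    }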

@@ -30,7 +30,7 @@
 namespace nd4j {
 namespace ops {
 class BroadcastHelper {
 public:
 static FORCEINLINE NDArray* broadcastApply(nd4j::BroadcastOpsTuple op, NDArray* x, NDArray* y, NDArray* z, ExtraArguments *extraArgs = nullptr) {
 if(x->isEmpty() || y->isEmpty()) {
@@ -42,34 +42,34 @@ namespace nd4j {
 std::unique_ptr<NDArray> ptr;
 if (!Environment::getInstance()->isExperimentalBuild()) {
 if (y->dataType() != x->dataType()) {
-y = y->cast(x->dataType());
+y = new NDArray(y->cast(x->dataType()));
 std::unique_ptr<NDArray> ptr2(y);
 ptr.swap(ptr2);
 }
 }
 if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) {
-x->applyPairwiseTransform(op.p, y, z, nullptr);
+x->applyPairwiseTransform(op.p, *y, *z);
 } else if (!x->isScalar() && y->isScalar()) {
-x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z);
+x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
 } else if (x->isScalar() && !y->isScalar()) {
 if (z->isSameShape(y)) {
 if (op.s == scalar::Add || op.s == scalar::Multiply ) {
-y->applyScalarArr(op.s, x, z, nullptr);
+y->applyScalarArr(op.s, *x, *z);
 } else if (op.s == scalar::SquaredSubtract) {
-y->applyScalarArr(scalar::SquaredReverseSubtract, x, z, nullptr);
+y->applyScalarArr(scalar::SquaredReverseSubtract, *x, *z);
 } else if (op.s == scalar::Subtract) {
-y->applyScalarArr(scalar::ReverseSubtract, x, z, nullptr);
+y->applyScalarArr(scalar::ReverseSubtract, *x, *z);
 } else if (op.s == scalar::Divide) {
-y->applyScalarArr(scalar::ReverseDivide, x, z, nullptr);
+y->applyScalarArr(scalar::ReverseDivide, *x, *z);
 } else if (op.s == scalar::Pow) {
-y->applyScalarArr(scalar::ReversePow, x, z, nullptr);
+y->applyScalarArr(scalar::ReversePow, *x, *z);
 } else if (op.s == scalar::ReverseSubtract) {
-y->applyScalarArr(scalar::Subtract, x, z, nullptr);
+y->applyScalarArr(scalar::Subtract, *x, *z);
 } else if (op.s == scalar::ReverseDivide) {
-y->applyScalarArr(scalar::Divide, x, z, nullptr);
+y->applyScalarArr(scalar::Divide, *x, *z);
 } else if (op.s == scalar::MaxPairwise || op.s == scalar::MinPairwise || op.s == scalar::AMaxPairwise || op.s == scalar::AMinPairwise) {
-y->applyScalarArr(op.s, x, z, nullptr);
+y->applyScalarArr(op.s, *x, *z);
 } else if (op.s == scalar::CopyPws) {
 z->assign(y);
 } else {
@@ -84,9 +84,9 @@ namespace nd4j {
 return tZ;
 }
 } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
-x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr);
+x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
 } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
-x->applyTrueBroadcast(op, y, z, true, extraArgs);
+x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
 return z;
 } else {
 auto sx = ShapeUtils::shapeAsString(x);
@@ -107,16 +107,16 @@ namespace nd4j {
 }
 if (!x->isScalar() && !y->isScalar() && x->isSameShape(y)) {
-x->applyPairwiseTransform(op.p, y, z, nullptr);
+x->applyPairwiseTransform(op.p, *y, *z);
 } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
-x->applyTrueBroadcast(op, y, z, true, extraArgs);
+x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
 return z;
 } else if (!x->isScalar() && y->isScalar()) {
-x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z);
+x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
 } else if (x->isScalar() && !y->isScalar()) {
 if (z->isSameShape(y)) {
 //z->assign(x);
-x->applyPairwiseTransform(op.p, y, z, extraArgs);
+x->applyPairwiseTransform(op.p, *y, *z, extraArgs);
 return z;
 } else {
 auto v = y->getShapeAsVector();
@@ -125,9 +125,9 @@ namespace nd4j {
 return tZ;
 }
 } else if (x->isScalar() && y->isScalar()) { // x->isScalar() && y->isScalar()
-x->applyScalarArr(op.s, const_cast<const NDArray*>(y), z, nullptr);
+x->applyScalarArr(op.s, const_cast<const NDArray&>(*y), *z);
 } else if (ShapeUtils::areShapesBroadcastable(*x, *y)) {
-x->applyTrueBroadcast(op, y, z, true, extraArgs);
+x->applyTrueBroadcast(op, *y, *z, true, extraArgs);
 return z;
 } else {
 auto sx = ShapeUtils::shapeAsString(x);

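For readers following the BroadcastHelper changes: applyPairwiseTransform, applyScalarArr and applyTransform now take the second operand and the target as references instead of pointers, and the trailing nullptr extra-arguments parameter disappears from most call sites. A minimal hedged sketch of the new call shapes, using only calls that appear in the diff; the helper and array names are invented:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    using namespace nd4j;

    // Sketch only: operand and target are passed by reference now.
    static void negateInto(NDArray& a, NDArray& out) {
        a.applyTransform(transform::Neg, out);            // no trailing nullptr
    }

    static void addElementwise(NDArray& a, NDArray& b, NDArray& out) {
        a.applyPairwiseTransform(pairwise::Add, b, out);  // element-wise, result written into out
    }

    static void addScalarArr(NDArray& a, const NDArray& scalarArr, NDArray& out) {
        a.applyScalarArr(scalar::Add, scalarArr, out);    // scalar operand passed as const reference
    }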

@@ -51,12 +51,12 @@ namespace nd4j {
 std::vector<int> axis = ShapeUtils::evalDimsToExclude(array->rankOf(), {0});
 auto tads = array->allTensorsAlongDimension( axis);
-for (int e = 0; e < tads->size(); e++) {
+for (int e = 0; e < tads.size(); e++) {
 auto idx = indices->e<int>(e);
-if (idx >= tads->size())
+if (idx >= tads.size())
 return ND4J_STATUS_BAD_ARGUMENTS;
-auto arr = tads->at(e)->dup(array->ordering());
+auto arr = new NDArray(tads.at(e)->dup(array->ordering()));
 auto res = list->write(idx, arr);
 if (res != ND4J_STATUS_OK)
 return res;
@@ -65,7 +65,6 @@ namespace nd4j {
 if (!hasList)
 //OVERWRITE_RESULT(list);
 setupResultList(list, block);
-delete tads;
 return Status::OK();
 }

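Another API shift visible above, noted here for clarity: allTensorsAlongDimension now returns its result set by value, so it is accessed with tads.size() / tads.at(e) and no longer freed with delete tads. A hedged sketch built only from calls shown in the diff; the helper name and the placeholder consumer are invented:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    #include <vector>
    using namespace nd4j;

    // Sketch only: iterate the sub-tensors of `array` along the given excluded-dimension axes.
    static void visitTads(NDArray* array, const std::vector<int>& axis) {
        auto tads = array->allTensorsAlongDimension(axis);   // result set returned by value, self-cleaning
        for (int e = 0; e < tads.size(); e++) {
            // tads.at(e) still yields an NDArray*; dup() gives a value, wrap it if ownership must move
            NDArray* copy = new NDArray(tads.at(e)->dup(array->ordering()));
            delete copy;   // placeholder: a real consumer (e.g. list->write) would take ownership instead
        }
    }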

@@ -55,7 +55,7 @@ namespace nd4j {
 std::vector<Nd4jLong> indices(2 * array->rankOf(), 0);
 for (Nd4jLong e = 0; e < sizes->lengthOf(); e++) {
 int c_size = sizes->e<int>(e);
 REQUIRE_TRUE(c_size > 0, 0, "Slice size should have postive value, but got %i instead", c_size);
 REQUIRE_TRUE(cnt < array->sizeAt(0) && cnt + c_size <= array->sizeAt(0), 0, "Slices size should NOT be higher then number of TADs of source array. Source size: [%i]; Slice start: [%i]; Slice size: [%i]", array->sizeAt(0), cnt, c_size);
@@ -63,11 +63,11 @@ namespace nd4j {
 indices[0] = cnt;
 indices[1] = cnt + c_size;
 cnt += c_size;
 auto subarray = (*array)(indices);
-auto status = list->write(e, subarray.dup(array->ordering()));
+auto status = list->write(e, new NDArray(subarray.dup(array->ordering())));
 if (status != ND4J_STATUS_OK)
 return status;
 }


@@ -39,7 +39,7 @@ namespace nd4j {
 //nd4j_printf("Writing [%i]:\n", idx->e<int>(0));
 //input->printShapeInfo("input shape");
 //input->printIndexedBuffer("input buffer");
-Nd4jStatus result = list->write(idx->e<int>(0), input->dup());
+Nd4jStatus result = list->write(idx->e<int>(0), new NDArray(input->dup()));
 auto res = NDArrayFactory::create_(list->counter(), block.launchContext());
 //res->printShapeInfo("Write_list 2 output shape");
@@ -52,7 +52,7 @@ namespace nd4j {
 auto input = INPUT_VARIABLE(1);
 auto idx = INT_ARG(0);
-Nd4jStatus result = list->write(idx, input->dup());
+Nd4jStatus result = list->write(idx, new NDArray(input->dup()));
 auto res = NDArrayFactory::create_(list->counter(), block.launchContext());
 //res->printShapeInfo("Write_list 1 output shape");


@@ -169,10 +169,10 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
 NDArray E = *predictions - *labels;
 // dE_i/dp_i = sign(p_i - y_i)
-E.applyTransform(nd4j::transform::Sign, dLdp); // dE/dp
+E.applyTransform(nd4j::transform::Sign, *dLdp); // dE/dp
 // dE_i/dy_i = -sign(p_i - y_i)
-E.applyTransform(nd4j::transform::Abs);
+E.applyTransform(nd4j::transform::Abs, E);
 switch (reductionMode) {
@@ -184,7 +184,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
 dLdw->assign(E.reduceNumber(reduce::Sum));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -210,7 +210,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
 *dLdw = 0.;
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -238,7 +238,7 @@ CUSTOM_OP_IMPL(absolute_difference_loss_grad, 3, 3, false, 0, 1) {
 dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeightsScalar;
 }
 else

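One more pattern that repeats through the loss-gradient files: the overload of reduceAlongDimension that writes into an existing array now takes the target by reference (*dLdw) instead of by pointer, with the keepDims flag and the remaining flags unchanged. A hedged sketch of how the weights-gradient reduction reads with the new signature; names other than those taken from the diff are invented:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    #include <vector>
    using namespace nd4j;

    // Sketch only: reduce the per-element loss E back onto the (smaller) weights gradient dLdw.
    static void reduceWeightsGrad(NDArray& E, NDArray* dLdw, const std::vector<int>& axesToReduceAlong) {
        // target passed by reference; `true` keeps the reduced dimensions, matching the calls in the diff
        E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
    }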

@@ -56,7 +56,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
 REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *output), 0, "COSINE_DISTANCE_LOSS OP: shapes of weights and output arrays should be broadcastable, but got weights = %s and output = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
 }
-NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true);
+NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true);
 // perform weights broadcasting/tile to E if it is necessary
 auto weightsBroad = weights;
@@ -70,7 +70,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
 case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels.
 output->assign(&E);
 break;
 case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array
 output->assign(E.reduceNumber(reduce::Sum));
 break;
@@ -79,12 +79,12 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
 NDArray sum;
 if (weights->isScalar())
 sum = *weights * E.lengthOf();
 else
 sum = weightsBroad->reduceNumber(reduce::Sum);
 if (sum.e<double>(0) == 0.)
 *output = 0.;
 else
 output->assign(E.reduceNumber(reduce::Sum) / sum);
 break;
 }
@@ -99,9 +99,9 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
 if (numOfNonZeroWeights == 0)
 *output = 0.;
 else
 output->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
 break;
 }
 }
@@ -111,7 +111,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss, 3, 1, false, 0, 2) {
 if(weightsBroad != weights)
 delete weightsBroad;
 return Status::OK();
 }
@@ -124,7 +124,7 @@ DECLARE_TYPES(cosine_distance_loss) {
 //////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(cosine_distance_loss) {
 // labels and predictions must have the same shapes
 auto predictionsShapeInfo = inputShape->at(0);
 auto weightsShapeInfo = inputShape->at(1);
 auto labelsShapeInfo = inputShape->at(2);
@@ -194,7 +194,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
 // input dimension can't be larger than labels/predictions/weights rank
 REQUIRE_TRUE(dim < labels->rankOf(), 0, "COSINE_DISTANCE_LOSS_GRAD OP: input reduction dimension (got %i) must be < labels rank %i!", dim, labels->rankOf());
-NDArray E = 1. - (*predictions * *labels).reduceAlongDims(reduce::Sum, {dim}, true);
+NDArray E = 1. - (*predictions * *labels).reduceAlongDimension(reduce::Sum, {dim}, true);
 // perform weights broadcasting/tile to E if it is necessary
 auto weightsBroad = weights;
@@ -216,7 +216,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
 else {
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -249,7 +249,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -284,7 +284,7 @@ CUSTOM_OP_IMPL(cosine_distance_loss_grad, 3, 3, false, 0, 2) {
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeights;
 }
 else

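Also worth flagging: reduceAlongDims has been folded into reduceAlongDimension in these files, keeping the {dim} axis list and the keepDims flag. A hedged sketch of the cosine-distance error term as computed above, written against that assumed overload; the function name and reference-typed parameters are illustrative only:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    using namespace nd4j;

    // Sketch only: E = 1 - sum(predictions * labels) along `dim`, keeping the reduced dimension.
    static NDArray cosineDistanceError(const NDArray& predictions, const NDArray& labels, int dim) {
        return 1. - (predictions * labels).reduceAlongDimension(reduce::Sum, {dim}, true);
    }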

@@ -52,7 +52,7 @@ namespace nd4j {
 // We first need to convert binary labels to -1/1 labels (as floats)
 NDArray E = 1.f - (*labels * 2.f - 1.f) * (*logits);
-E.applyScalar(scalar::RELU, 0.0f, &E);
+E.applyScalar(scalar::RELU, 0.0f, E);
 // multiply E on weights
 E *= *weightsBroad;
@@ -172,11 +172,11 @@ namespace nd4j {
 NDArray z = (*labels * 2.f - 1.f);
 NDArray E = 1.f - z * (*logits);
-E.applyScalar(scalar::RELU, 0.0f, &E);
+E.applyScalar(scalar::RELU, 0.0f, E);
 // turn E into gradient mask
 NDArray gradientMask(E.getShapeInfo(), block.getWorkspace());
-E.applyTransform(nd4j::transform::Sign, &gradientMask);
+E.applyTransform(nd4j::transform::Sign, gradientMask);
 dLdp->assign(-z * gradientMask);
 dLdl->assign(-2.f * (*logits) * gradientMask);
@@ -192,7 +192,7 @@ namespace nd4j {
 dLdw->assign(E.reduceNumber(reduce::Sum));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -220,7 +220,7 @@ namespace nd4j {
 *dLdw = 0.;
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -249,7 +249,7 @@ namespace nd4j {
 dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeightsScalar;
 }
 else

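Note on the hinge-loss change above (the same idiom appears in huber_loss below): applyScalar and applyTransform now take the output array by reference, and passing the array itself performs the operation in place. A minimal hedged sketch using only calls shown in the diff; the wrapper name is invented:

    #include <NDArray.h>   // assumption: libnd4j headers available on the include path
    using namespace nd4j;

    // Sketch only: clamp E at zero in place, as the hinge-loss code above does.
    static void reluInPlace(NDArray& E) {
        E.applyScalar(scalar::RELU, 0.0f, E);   // output reference may alias the input -> in-place update
    }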

@@ -46,17 +46,17 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
 REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "HUBER_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
 // only 4 possible reduction modes exist
 REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "HUBER_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
 // perform weights broadcasting/tile to predictions if needed
 auto weightsBroad = weights;
 if(!weights->isScalar() && !weights->isSameShape(predictions))
 weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
 auto error = *predictions - *labels;
-error.applyTransform(transform::Abs);
+error.applyTransform(transform::Abs, error);
 NDArray quadratic(error.getShapeInfo(), block.getWorkspace());
-error.applyScalar(scalar::MinPairwise, delta, &quadratic);
+error.applyScalar(scalar::MinPairwise, delta, quadratic);
 NDArray E = quadratic * quadratic * 0.5f + (error - quadratic)*delta;
 // multiply E on weights
@@ -75,12 +75,12 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
 NDArray sum;
 if (weights->isScalar())
 sum = *weights * E.lengthOf();
 else
 sum = weightsBroad->reduceNumber(reduce::Sum);
 if (sum.e<double>(0) == 0.)
 *output = 0.;
 else
 output->assign(E.reduceNumber(reduce::Sum) / sum);
 break;
 }
@@ -104,7 +104,7 @@ CUSTOM_OP_IMPL(huber_loss, 3, 1, false, 1, 1) {
 if(weightsBroad != weights)
 delete weightsBroad;
 return Status::OK();
 }
@@ -173,24 +173,24 @@ DECLARE_SHAPE_FN(huber_loss) {
 NDArray diff = *predictions - *labels;
 NDArray absDiff(diff);
-absDiff.applyTransform(transform::Abs);
+absDiff.applyTransform(transform::Abs, absDiff);
 NDArray quadratic(absDiff);
-absDiff.applyScalar(scalar::MinPairwise, delta, &quadratic);
+absDiff.applyScalar(scalar::MinPairwise, delta, quadratic);
 NDArray E = quadratic * quadratic * 0.5f + (absDiff - quadratic)*delta;
 NDArray lteMask(diff.getShapeInfo(), BOOL, true, block.launchContext());
-absDiff.applyScalar(scalar::LessThanOrEqual, delta, &lteMask);
+absDiff.applyScalar(scalar::LessThanOrEqual, delta, lteMask);
 NDArray gtMask(diff.getShapeInfo(), BOOL, true, block.launchContext());
-absDiff.applyScalar(scalar::GreaterThan, delta, &gtMask);
+absDiff.applyScalar(scalar::GreaterThan, delta, gtMask);
 NDArray signDiff(diff);
-diff.applyTransform(transform::Sign, &signDiff);
+diff.applyTransform(transform::Sign, signDiff);
-auto gtMaskFloat = *gtMask.cast(diff.dataType());
+auto gtMaskFloat = gtMask.cast(diff.dataType());
-auto lteMaskFloat = *lteMask.cast(diff.dataType());
+auto lteMaskFloat = lteMask.cast(diff.dataType());
 dLdp->assign( lteMaskFloat * diff + gtMaskFloat * delta * signDiff);
@@ -207,7 +207,7 @@ DECLARE_SHAPE_FN(huber_loss) {
 dLdw->assign(E.reduceNumber(reduce::Sum));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -235,7 +235,7 @@ DECLARE_SHAPE_FN(huber_loss) {
 *dLdw = 0.;
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -264,7 +264,7 @@ DECLARE_SHAPE_FN(huber_loss) {
 dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeightsScalar;
 }
 else

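For reference, the element-wise quantity that the huber_loss kernels above assemble with NDArray ops is the standard Huber error. A small self-contained sketch of the same arithmetic on plain scalars, offered only as an illustration of the formula, not as code from the PR:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Element-wise Huber error, mirroring the diff:
    //   quadratic = min(|p - y|, delta)
    //   E         = 0.5 * quadratic^2 + (|p - y| - quadratic) * delta
    static double huberError(double prediction, double label, double delta) {
        const double absDiff   = std::fabs(prediction - label);
        const double quadratic = std::min(absDiff, delta);
        return 0.5 * quadratic * quadratic + (absDiff - quadratic) * delta;
    }

    int main() {
        // |diff| = 1.5 > delta = 1.0, so the linear branch applies: 0.5*1 + 0.5*1 = 1.0
        std::printf("%f\n", huberError(2.5, 1.0, 1.0));
        return 0;
    }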

@@ -29,11 +29,11 @@ namespace ops {
 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
 auto predictions = INPUT_VARIABLE(0);
 auto weights = INPUT_VARIABLE(1);
 auto labels = INPUT_VARIABLE(2);
 auto output = OUTPUT_VARIABLE(0);
 int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
@@ -48,7 +48,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
 REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
 // only 4 possible reduction modes exist
 REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
 // perform weights broadcasting/tile to predictions if needed
 auto weightsBroad = weights;
 if(!weights->isScalar() && !weights->isSameShape(predictions))
@@ -58,7 +58,7 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
 // multiply E on weights
 E *= *weightsBroad;
 switch (reductionMode) {
 case 0: { // 0 - "none", un-reduced weighted losses with the same shape as labels.
 output->assign(E);
@@ -72,12 +72,12 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
 NDArray sum;
 if (weights->isScalar())
 sum = *weights * E.lengthOf();
 else
 sum = weightsBroad->reduceNumber(reduce::Sum);
 if (sum.e<double>(0) == 0.)
 *output = 0.;
 else
 output->assign(E.reduceNumber(reduce::Sum) / sum);
 break;
 }
@@ -101,13 +101,13 @@ CUSTOM_OP_IMPL(log_loss, 3, 1, false, 1, 1) {
 if(weightsBroad != weights)
 delete weightsBroad;
 return Status::OK();
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_TYPES(log_loss) {
 getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
 }
@@ -118,11 +118,11 @@ DECLARE_SHAPE_FN(log_loss) {
 auto weightsShapeInfo = inputShape->at(1);
 auto labelsShapeInfo = inputShape->at(2);
 // labels and predictions must have the same shapes
 REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str());
 // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
 REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
 // check whether broadcast operation is possible for weights array
 REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
 DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo));
@@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(log_loss) {
 outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType);
 else // in this case output has the same shape as labels and predictions
 outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));
 return SHAPELIST(outShapeInfo);
 }
@@ -143,33 +143,33 @@ DECLARE_SHAPE_FN(log_loss) {
 //////////////////////////////////////////////////////////////////////////
 CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 auto predictions = INPUT_VARIABLE(0);
 auto weights = INPUT_VARIABLE(1);
 auto labels = INPUT_VARIABLE(2);
 auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions
 auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights
 auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels
 int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
 // take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients
 if(reductionMode == 0)
 reductionMode = 1;
 // FIXME: double?
 double epsilon = T_ARG(0);
 // input validation
 REQUIRE_TRUE(labels->isSameShape(predictions), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str());
 // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
 REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf());
 // check whether broadcast operation is possible for weights array
 REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
 // only 4 possible reduction modes exist
 REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "LOG_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
 // perform weights broadcasting/tile to labels if needed
 auto weightsBroad = weights;
 if(!weights->isScalar() && !weights->isSameShape(predictions))
 weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
@@ -179,24 +179,24 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 NDArray onePlusEpsMinusPredict = (1. + epsilon) - *predictions;
 // dE_i/dp_i = (1-y_i)/(1-p_i+eps) - y_i/(p_i+eps)
 dLdp->assign(oneMinusLabels / onePlusEpsMinusPredict - *labels / predictPlusEps); // dE/dp
 // dE_i/dy_i = log((1+2eps)/(p_i+eps) - 1)
-((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, dLdl); // dE/dy
+((1. + 2. * epsilon) / predictPlusEps - 1.).applyTransform(transform::Log, *dLdl); // dE/dy
 NDArray E = -(*labels) * predictPlusEps.transform(transform::Log) - oneMinusLabels * onePlusEpsMinusPredict.transform(transform::Log);
 // process 3 possible reduction modes below
 switch (reductionMode) {
 case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array
 *dLdp *= *weightsBroad;
 *dLdl *= *weightsBroad;
 if(weights->isScalar())
 dLdw->assign(E.reduceNumber(reduce::Sum));
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -208,9 +208,9 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 NDArray sum;
 if (weights->isScalar())
 sum = (*weights) * E.lengthOf();
 else
 sum = weightsBroad->reduceNumber(reduce::Sum);
 if (sum.e<double>(0) == 0.) {
 *dLdp = 0.;
 *dLdl = 0.;
@@ -221,27 +221,27 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 NDArray temp = *weightsBroad / sum;
 *dLdp *= temp;
 *dLdl *= temp;
 if(weights->isScalar())
 *dLdw = 0.;
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
 }
 break;
 }
 case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights
 Nd4jLong numOfNonZeroWeights = 0;
 if(weights->isScalar()) {
 if(weights->e<double>(0) != 0.)
 numOfNonZeroWeights = E.lengthOf();
 }
 else
 numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e<Nd4jLong>(0);
 if (numOfNonZeroWeights == 0) {
 *dLdp = 0.;
@@ -254,12 +254,12 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeights);
 else if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeightsScalar;
 }
 else
 dLdw->assign(E / numOfNonZeroWeightsScalar);
 NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar;
 *dLdp *= temp;
 *dLdl *= temp;
@@ -270,13 +270,13 @@ CUSTOM_OP_IMPL(log_loss_grad, 3, 3, false, 1, 1) {
 if(weightsBroad != weights)
 delete weightsBroad;
 return Status::OK();
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_TYPES(log_loss_grad) {
 getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
 }
@@ -287,19 +287,19 @@ DECLARE_SHAPE_FN(log_loss_grad) {
 auto weightsShapeInfo = inputShape->at(1);
 auto labelsShapeInfo = inputShape->at(2);
 // labels and predictions must have the same shapes
 REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "LOG_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str());
 // weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
 REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
 // check whether broadcast operation is possible for weights array
 REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "LOG_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
 DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo));
 auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace());
 auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace());
 auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace());
 return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo));
 }

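The log_loss kernels above encode the usual binary cross-entropy with an epsilon guard; the per-element formulas are spelled out in the comments of the diff. A small self-contained scalar sketch of the same expressions, offered as an illustration only, not as code from the PR:

    #include <cmath>
    #include <cstdio>

    // E_i     = -y*log(p + eps) - (1 - y)*log(1 - p + eps)
    // dE_i/dp = (1 - y)/(1 - p + eps) - y/(p + eps)
    // dE_i/dy = log((1 + 2*eps)/(p + eps) - 1)
    static double logLoss(double p, double y, double eps) {
        return -y * std::log(p + eps) - (1.0 - y) * std::log(1.0 - p + eps);
    }
    static double dLogLoss_dp(double p, double y, double eps) {
        return (1.0 - y) / (1.0 - p + eps) - y / (p + eps);
    }
    static double dLogLoss_dy(double p, double eps) {
        return std::log((1.0 + 2.0 * eps) / (p + eps) - 1.0);
    }

    int main() {
        const double p = 0.8, y = 1.0, eps = 1e-7;
        std::printf("E=%f dE/dp=%f dE/dy=%f\n", logLoss(p, y, eps), dLogLoss_dp(p, y, eps), dLogLoss_dy(p, eps));
        return 0;
    }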

@@ -55,9 +55,9 @@ namespace ops {
NDArray E(labels->getShapeInfo(), block.getWorkspace());
if (computeFullLoss)
- labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr);
+ labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E);
else
- labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr);
+ labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E);
// multiply E on weights
@@ -176,19 +176,19 @@ namespace ops {
NDArray E(labels->getShapeInfo(), block.getWorkspace());
if (computeFullLoss) {
- labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, log_predictions, &E, nullptr);
+ labels->applyPairwiseTransform(pairwise::LogPoissonLossFull, *log_predictions, E);
NDArray rDiv(labels->getShapeInfo(), block.getWorkspace());
- labels->applyScalar(scalar::ReverseDivide, 0.5f, &rDiv);
+ labels->applyScalar(scalar::ReverseDivide, 0.5f, rDiv);
dLdl->assign(rDiv + labels->transform(transform::Log) + -(*log_predictions));
} else {
- labels->applyPairwiseTransform(pairwise::LogPoissonLoss, log_predictions, &E, nullptr);
+ labels->applyPairwiseTransform(pairwise::LogPoissonLoss, *log_predictions, E);
dLdl->assign(-(*log_predictions));
}
dLdp->assign(log_predictions->transform(transform::Exp) - (*labels));
switch (reductionMode) {
case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array
@@ -200,7 +200,7 @@ namespace ops {
dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign(E);
@@ -228,7 +228,7 @@ namespace ops {
*dLdw = 0.;
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -257,7 +257,7 @@ namespace ops {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
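For readers following the math rather than the API change, here is a minimal standalone sketch of the element-wise log-Poisson loss and its gradients, mirroring the E, dLdp and dLdl expressions above. It uses plain doubles instead of NDArray; the Stirling term used for the full-loss variant follows the common TensorFlow-style definition and should be treated as an assumption about what LogPoissonLossFull computes.

    #include <cmath>
    #include <cstdio>

    // Element-wise log-Poisson loss for log-prediction x = log(prediction) and label z.
    // Plain form: exp(x) - z * x   (the pairwise::LogPoissonLoss op, before weighting).
    // Full form adds an approximation of log(z!) via Stirling's formula for z > 1 (assumed definition).
    double logPoissonLoss(double x, double z, bool computeFullLoss) {
        const double PI = 3.141592653589793;
        double e = std::exp(x) - z * x;
        if (computeFullLoss && z > 1.0)
            e += z * std::log(z) - z + 0.5 * std::log(2.0 * PI * z);
        return e;
    }

    // Gradient w.r.t. x, matching dLdp->assign(exp(log_predictions) - labels).
    double logPoissonLossGradX(double x, double z) {
        return std::exp(x) - z;
    }

    // Gradient w.r.t. z for the plain form, matching dLdl->assign(-(*log_predictions)).
    double logPoissonLossGradZ(double x) {
        return -x;
    }

    int main() {
        printf("E = %f, dE/dx = %f, dE/dz = %f\n",
               logPoissonLoss(0.3, 2.0, false), logPoissonLossGradX(0.3, 2.0), logPoissonLossGradZ(0.3));
        return 0;
    }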


@@ -112,10 +112,10 @@ namespace nd4j {
auto n = double(labels->sizeAt(1));
auto diffs = *predictions - *labels;
- auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true);
- auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true);
- squareOfSum.applyScalar(scalar::Pow, 2);
+ squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum);
auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1)));
@@ -240,15 +240,15 @@ namespace nd4j {
auto diffs = *predictions - *labels;
std::vector<int> reductionIdx = ShapeUtils::evalDimsToExclude(labels->rankOf(), {0});
- auto sumOfSquares = (diffs * diffs).reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto sumOfSquares = (diffs * diffs).reduceAlongDimension(reduce::Sum, reductionIdx, true);
- auto squareOfSum = diffs.reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto squareOfSum = diffs.reduceAlongDimension(reduce::Sum, reductionIdx, true);
- squareOfSum.applyScalar(scalar::Pow, 2);
+ squareOfSum.applyScalar(scalar::Pow, 2, squareOfSum);
auto E = ((sumOfSquares * n) - squareOfSum) * (4/(n*(n-1)));
- auto sumPred = predictions->reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto sumPred = predictions->reduceAlongDimension(reduce::Sum, reductionIdx, true);
- auto sumLabel = labels->reduceAlongDims(reduce::Sum, reductionIdx, true);
+ auto sumLabel = labels->reduceAlongDimension(reduce::Sum, reductionIdx, true);
dLdp->assign(((diffs * n) - sumPred + sumLabel)*(8/(n*(n-1))));
@@ -273,7 +273,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign(E);
@@ -299,7 +299,7 @@ namespace nd4j {
*dLdw = 0.;
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -327,7 +327,7 @@ namespace nd4j {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar;
}
else
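A per-sample sketch of the mean pairwise squared error math used above, with std::vector in place of NDArray. The 4/(n(n-1)) and 8/(n(n-1)) factors mirror the E and dLdp expressions in the diff; names are illustrative only.

    #include <cstdio>
    #include <vector>

    // Per-sample mean pairwise squared error for one row of predictions p and labels y,
    // following E = (n * sum(d_i^2) - (sum d_i)^2) * 4 / (n * (n - 1)), with d = p - y.
    double meanPairwiseSqErr(const std::vector<double>& p, const std::vector<double>& y) {
        const double n = static_cast<double>(p.size());
        double sumOfSquares = 0.0, sum = 0.0;
        for (size_t i = 0; i < p.size(); ++i) {
            const double d = p[i] - y[i];
            sumOfSquares += d * d;
            sum += d;
        }
        return (n * sumOfSquares - sum * sum) * (4.0 / (n * (n - 1.0)));
    }

    // Gradient w.r.t. p_i, matching dLdp = (n * d_i - (sumPred - sumLabel)) * 8 / (n * (n - 1)).
    std::vector<double> meanPairwiseSqErrGrad(const std::vector<double>& p, const std::vector<double>& y) {
        const double n = static_cast<double>(p.size());
        double sum = 0.0;
        for (size_t i = 0; i < p.size(); ++i) sum += p[i] - y[i];
        std::vector<double> g(p.size());
        for (size_t i = 0; i < p.size(); ++i)
            g[i] = (n * (p[i] - y[i]) - sum) * (8.0 / (n * (n - 1.0)));
        return g;
    }

    int main() {
        std::vector<double> p {1.0, 2.0, 4.0}, y {1.5, 1.5, 3.0};
        printf("E = %f\n", meanPairwiseSqErr(p, y));
        return 0;
    }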


@@ -35,8 +35,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
auto output = OUTPUT_VARIABLE(0);
int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
// inputs validation
REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf());
@@ -45,13 +45,13 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
// perform weights broadcasting/tile to labels if needed
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(predictions))
weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
NDArray E(labels->getShapeInfo(), false, block.launchContext());
- predictions->applyPairwiseTransform(pairwise::SquaredSubtract, labels, &E, nullptr);
+ predictions->applyPairwiseTransform(pairwise::SquaredSubtract, *labels, E);
// multiply E on weights
E *= (*weightsBroad);
@@ -60,7 +60,7 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels.
output->assign(&E);
break;
case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array
E.reduceNumber(reduce::Sum, *output);
break;
@@ -69,12 +69,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
NDArray sum;
if (weights->isScalar())
sum = (*weights) * E.lengthOf();
else
sum = weightsBroad->reduceNumber(reduce::Sum);
if (sum.e<double>(0) == 0.)
(*output) = 0.;
else
output->assign(E.reduceNumber(reduce::Sum) / sum);
break;
}
@@ -101,12 +101,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss, 3, 1, false, 0, 1) {
if(weightsBroad != weights)
delete weightsBroad;
return Status::OK();
}
DECLARE_TYPES(mean_sqerr_loss) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}
@@ -121,7 +121,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) {
REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo));
@@ -132,7 +132,7 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) {
else // in this case output has the same shape as labels and predictions
outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));
return SHAPELIST(outShapeInfo);
}
@@ -144,11 +144,11 @@ DECLARE_SHAPE_FN(mean_sqerr_loss) {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
auto predictions = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto labels = INPUT_VARIABLE(2);
auto dLdp = OUTPUT_VARIABLE(0); // dL/dpredictions
auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights
auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels
@@ -157,8 +157,8 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
// take into account Alex's proposition to treat "none" the same as "weighted_sum" mode when calculating gradients
if(reductionMode == 0)
reductionMode = 1;
// inputs validation
REQUIRE_TRUE(labels->isSameShape(predictions), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(predictions).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf());
@@ -167,9 +167,9 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "MEAN_SQERR_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
// perform weights broadcasting/tile to labels if needed
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(predictions))
weightsBroad = new NDArray(weights->tileToShape(predictions->getShapeInfo()));
NDArray diff = *predictions - *labels;
@@ -178,20 +178,20 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
dLdp->assign(2. * diff); // dE/dp
// dE_i/dy_i = -2 * (p_i - y_i)
// dLdl->assign(-(*dLdp)); // dE/dl
NDArray E = diff * diff;
switch (reductionMode) {
case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array
*dLdp *= *weightsBroad;
if(weights->isScalar())
dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign(E);
@@ -202,40 +202,40 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
NDArray sum;
if (weights->isScalar())
sum = (*weights) * E.lengthOf();
else
sum = weightsBroad->reduceNumber(reduce::Sum);
if (sum.e<double>(0) == 0.) {
*dLdp = 0.;
*dLdw = 0.;
}
else {
*dLdp *= *weightsBroad / sum;
if(weights->isScalar())
*dLdw = 0.;
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
}
break;
}
case 3: { // 3 - "weighted_sum_by_nonzero_weights", output is scalar and equal to scalar sum of all elements of E array divided by number of non-zero weights
Nd4jLong numOfNonZeroWeights = 0;
if(weights->isScalar()) {
if(weights->e<double>(0) != 0.)
numOfNonZeroWeights = E.lengthOf();
}
else
numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e<Nd4jLong>(0);
if (numOfNonZeroWeights == 0) {
*dLdp = 0.;
*dLdw = 0.;
}
else {
@@ -245,14 +245,14 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / double(numOfNonZeroWeights));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar;
}
else
dLdw->assign(E / numOfNonZeroWeights);
NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar;
*dLdp *= temp;
}
break;
}
@@ -262,12 +262,12 @@ CUSTOM_OP_IMPL(mean_sqerr_loss_grad, 3, 3, false, 0, 1) {
if(weightsBroad != weights)
delete weightsBroad;
return Status::OK();
}
DECLARE_TYPES(mean_sqerr_loss_grad) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}
@@ -281,15 +281,15 @@ DECLARE_SHAPE_FN(mean_sqerr_loss_grad) {
REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, predictionsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: labels and predictions arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(predictionsShapeInfo).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "MEAN_SQERR_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(predictionsShapeInfo));
auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(predictionsShapeInfo, outType, false, block.getWorkspace());
auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace());
auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace());
return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo));
}
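The switch(reductionMode) blocks that recur in every loss op above reduce an already-weighted loss array E in one of four ways. A minimal standalone summary of that control flow, assuming E already holds the weighted per-element losses (plain vectors, not the NDArray API):

    #include <cstdio>
    #include <vector>

    // Reduce an already-weighted loss array E, mirroring the switch(reductionMode) blocks:
    //  1 - "weighted_sum":                     sum(E)
    //  2 - "weighted_mean":                    sum(E) / sum(weights), or 0 if the weight sum is 0
    //  3 - "weighted_sum_by_nonzero_weights":  sum(E) / count(weights != 0), or 0 if none are non-zero
    // Mode 0 ("none") keeps E un-reduced; the scalar returned here is just a stand-in for that case.
    double reduceLoss(const std::vector<double>& E, const std::vector<double>& weights, int reductionMode) {
        double sumE = 0.0, sumW = 0.0;
        long nonZero = 0;
        for (size_t i = 0; i < E.size(); ++i) {
            sumE += E[i];
            sumW += weights[i];
            if (weights[i] != 0.0) ++nonZero;
        }
        switch (reductionMode) {
            case 2:  return sumW == 0.0 ? 0.0 : sumE / sumW;
            case 3:  return nonZero == 0 ? 0.0 : sumE / static_cast<double>(nonZero);
            default: return sumE; // modes 0 and 1
        }
    }

    int main() {
        std::vector<double> E {0.5, 1.5, 2.0}, w {1.0, 0.0, 1.0};
        printf("sum=%f mean=%f by_nonzero=%f\n", reduceLoss(E, w, 1), reduceLoss(E, w, 2), reduceLoss(E, w, 3));
        return 0;
    }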


@@ -38,27 +38,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
auto labelsSmoothing = T_ARG(0);
// input validation
REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf());
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
// perform weights broadcasting/tile to labels if needed
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(logits))
weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo()));
// If labelsSmoothing is nonzero, smooth the labels towards 1/2:
auto newLabels = labels;
if(labelsSmoothing != 0.) {
newLabels = new NDArray(*labels);
- newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, newLabels, nullptr);
+ newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing, *newLabels);
}
NDArray E(labels, false, block.launchContext());
// logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits
@@ -66,12 +66,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
// multiply E on weights
E *= *weightsBroad;
switch (reductionMode) {
case 0: // 0 - "none", un-reduced weighted losses with the same shape as labels.
output->assign(E);
break;
case 1: { // 1 - "weighted_sum", output is scalar and equal to sum of all elements of E array
E.reduceNumber(reduce::Sum, *output);
break;
@@ -80,12 +80,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
NDArray sum;
if (weights->isScalar())
sum = (*weights) * E.lengthOf();
else
sum = weightsBroad->reduceNumber(reduce::Sum);
if (sum.e<double>(0) == 0.)
*output = 0.;
else
output->assign(E.reduceNumber(reduce::Sum) / sum);
break;
}
@@ -111,13 +111,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss, 3, 1, false, 1, 1) {
delete weightsBroad;
if(newLabels != labels)
delete newLabels;
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(sigm_cross_entropy_loss) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}
@@ -128,11 +128,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) {
auto weightsShapeInfo = inputShape->at(1);
auto labelsShapeInfo = inputShape->at(2);
// labels and logits must have the same shapes
REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
@@ -142,8 +142,8 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss) {
outShapeInfo = ConstantShapeHelper::getInstance()->scalarShapeInfo(outType);
else // in this case output has the same shape as labels and logits
outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));
return SHAPELIST(outShapeInfo);
}
@@ -155,12 +155,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
auto logits = INPUT_VARIABLE(0);
auto weights = INPUT_VARIABLE(1);
auto labels = INPUT_VARIABLE(2);
auto dLdp = OUTPUT_VARIABLE(0); // dL/dlogits
auto dLdw = OUTPUT_VARIABLE(1); // dL/dweights
auto dLdl = OUTPUT_VARIABLE(2); // dL/dlabels
NDArray labelsSmoothing = NDArrayFactory::create(logits->dataType(), T_ARG(0), block.launchContext());
int reductionMode = INT_ARG(0); // 0 - "none"; 1 - "weighted_sum"; 2 - "weighted_mean"; 3 - "weighted_sum_by_nonzero_weights"
@@ -168,27 +168,27 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
if(reductionMode == 0)
reductionMode = 1;
// input validation
REQUIRE_TRUE(labels->isSameShape(logits), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(weights->isScalar() || weights->rankOf() == labels->rankOf(), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", weights->rankOf(), labels->rankOf());
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(weights->isScalar() || ShapeUtils::areShapesBroadcastable(*weights, *labels), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weights).c_str(), ShapeUtils::shapeAsString(labels).c_str());
// only 4 possible reduction modes exist
REQUIRE_TRUE(reductionMode==0 || reductionMode==1 || reductionMode==2 || reductionMode==3, 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: reduction mode value is not acceptable, possible values are 0, 1, 2, 3, but got %i instead!", reductionMode);
// perform weights broadcasting/tile to labels if needed
auto weightsBroad = weights;
if(!weights->isScalar() && !weights->isSameShape(logits))
weightsBroad = new NDArray(weights->tileToShape(logits->getShapeInfo()));
// If labelsSmoothing is nonzero, smooth the labels towards 1/2:
auto newLabels = labels;
if(labelsSmoothing.e<float>(0) != 0.f) {
newLabels = new NDArray(*labels);
- newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e<float>(0), newLabels, nullptr);
+ newLabels->applyScalar(scalar::SXELogitsSmoother, labelsSmoothing.e<float>(0), *newLabels);
}
NDArray E(labels, false, block.launchContext());
// logits - labels * logits + log(1 + exp(-logits)) -> take into account numerical stability at large logits
@@ -196,24 +196,24 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
// dLdp = 1 - labels - 1 / (1 + exp(logits))
helpers::sigmCrossEntropyGrad(block.launchContext(), logits, newLabels, dLdp);
// dLdl = -logits
labelsSmoothing -= 1.f;
dLdl->assign(*logits * labelsSmoothing);
switch (reductionMode) {
case 1: { // 1 - "none" and "weighted_sum", output is scalar and equal to sum of all elements of E array
*dLdp *= *weightsBroad;
*dLdl *= *weightsBroad;
if(weights->isScalar())
dLdw->assign(E.reduceNumber(reduce::Sum));
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign(E);
break;
}
@@ -221,9 +221,9 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
NDArray sum;
if (weights->isScalar())
sum = (*weights) * E.lengthOf();
else
sum = weightsBroad->reduceNumber(reduce::Sum);
if (sum.e<double>(0) == 0.) {
*dLdp = 0.;
*dLdl = 0.;
@@ -234,14 +234,14 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
NDArray temp = *weightsBroad / sum;
*dLdp *= temp;
*dLdl *= temp;
if(weights->isScalar())
*dLdw = 0.;
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ ((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
}
else
dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum * sum));
}
break;
@@ -252,8 +252,8 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
if(weights->e<double>(0) != 0.)
numOfNonZeroWeights = E.lengthOf();
}
else
numOfNonZeroWeights = weightsBroad->reduceNumber(reduce::CountNonZero).e<Nd4jLong>(0);
if (numOfNonZeroWeights == 0) {
*dLdp = 0.;
@@ -267,12 +267,12 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
dLdw->assign(E.reduceNumber(reduce::Sum) / numOfNonZeroWeightsScalar);
else if(weights != weightsBroad) {
std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
- E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+ E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
*dLdw /= numOfNonZeroWeightsScalar;
}
else
dLdw->assign(E / numOfNonZeroWeightsScalar);
NDArray temp = *weightsBroad / numOfNonZeroWeightsScalar;
*dLdp *= temp;
*dLdl *= temp;
@@ -285,13 +285,13 @@ CUSTOM_OP_IMPL(sigm_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
delete weightsBroad;
if(newLabels != labels)
delete newLabels;
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_TYPES(sigm_cross_entropy_loss_grad) {
getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
}
@@ -302,11 +302,11 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) {
auto weightsShapeInfo = inputShape->at(1);
auto labelsShapeInfo = inputShape->at(2);
// labels and logits must have the same shapes
REQUIRE_TRUE(shape::shapeEquals(labelsShapeInfo, logitsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
// weights array can be single scalar or has the same rank as labels, and must be broadcastable to labels
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || shape::rank(weightsShapeInfo) == shape::rank(labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: weights array should be scalar or have the same rank as labels array, but got %i and %i correspondingly!", shape::rank(weightsShapeInfo), shape::rank(labelsShapeInfo));
// check whether broadcast operation is possible for weights array
REQUIRE_TRUE(shape::isScalar(weightsShapeInfo) || ShapeUtils::areShapesBroadcastable(weightsShapeInfo, labelsShapeInfo), 0, "SIGM_CROSS_ENTROPY_LOSS_GRAD OP: shapes of weights and labels arrays should be broadcastable, but got weights = %s and labels = %s instead!", ShapeUtils::shapeAsString(weightsShapeInfo).c_str(), ShapeUtils::shapeAsString(labelsShapeInfo).c_str());
DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
@@ -314,7 +314,7 @@ DECLARE_SHAPE_FN(sigm_cross_entropy_loss_grad) {
auto dLdpShapeInfo = ShapeBuilders::copyShapeInfoAndType(logitsShapeInfo, outType, false, block.getWorkspace());
auto dLdwShapeInfo = ShapeBuilders::copyShapeInfoAndType(weightsShapeInfo, outType, false, block.getWorkspace());
auto dLdlShapeInfo = ShapeBuilders::copyShapeInfoAndType(labelsShapeInfo, outType, false, block.getWorkspace());
return SHAPELIST(CONSTANT(dLdpShapeInfo), CONSTANT(dLdwShapeInfo), CONSTANT(dLdlShapeInfo));
}
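The comments above give the element-wise formula E = logits - labels * logits + log(1 + exp(-logits)) and the gradient dLdp = 1 - labels - 1 / (1 + exp(logits)). A standalone sketch of both, using the usual branch that keeps exp() from overflowing for large-magnitude logits (plain doubles, not the NDArray kernels; label smoothing towards 1/2 is assumed to have been applied to the label already):

    #include <cmath>
    #include <cstdio>

    // Element-wise sigmoid cross-entropy for logit x and (possibly smoothed) label z.
    // Mathematically this is x - x*z + log(1 + exp(-x)); the branch below is the common
    // rearrangement that avoids overflow of exp(-x) when x is a large negative number.
    double sigmCrossEntropy(double x, double z) {
        if (x >= 0.0)
            return x - x * z + std::log1p(std::exp(-x));
        return -x * z + std::log1p(std::exp(x));
    }

    // Gradient w.r.t. the logit: 1 - z - 1/(1 + exp(x)), i.e. sigmoid(x) - z.
    double sigmCrossEntropyGradLogit(double x, double z) {
        return 1.0 - z - 1.0 / (1.0 + std::exp(x));
    }

    int main() {
        printf("E = %f, dE/dx = %f\n", sigmCrossEntropy(2.0, 1.0), sigmCrossEntropyGradLogit(2.0, 1.0));
        return 0;
    }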


@ -54,11 +54,11 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
// If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
// num_classes = labels->sizeAt(1) // num_classes = labels->sizeAt(1)
auto cLabels = labels->cast(weights->dataType()); NDArray* cLabels = new NDArray(labels->cast(weights->dataType()));
auto newLabels = cLabels; NDArray* newLabels = cLabels;
if(labelsSmoothing != 0.) { if(labelsSmoothing != 0.) {
newLabels = new NDArray(cLabels); newLabels = new NDArray(cLabels);
*newLabels = (1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
} }
// main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension // main formula: result = - sum_i(lables_i * log(softmax_i)) - sum over last dimension
@ -70,9 +70,9 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss, 3, 1, false, 1, 1) {
std::vector<int> dimensions = {-1}; std::vector<int> dimensions = {-1};
NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true); NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true);
NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log); NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log);
NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions); NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions);
// perform weights broadcasting/tile to E if it is necessary // perform weights broadcasting/tile to E if it is necessary
auto weightsBroad = weights; auto weightsBroad = weights;
@ -217,25 +217,25 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
// If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes // If label_smoothing is nonzero, smooth the labels towards 1/num_classes: new_onehot_labels = onehot_labels * (1 - label_smoothing) + label_smoothing / num_classes
// num_classes = labels->sizeAt(1) // num_classes = labels->sizeAt(1)
auto cLabels = labels->cast(weights->dataType()); NDArray* cLabels = new NDArray(labels->cast(weights->dataType()));
auto newLabels = cLabels; NDArray* newLabels = cLabels;
if(labelsSmoothing != 0.) { if(labelsSmoothing != 0.) {
newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext()); newLabels = new NDArray(labels->getShapeInfo(), dLdl->dataType(), false, block.launchContext());
newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1)); newLabels->assign((1.f - labelsSmoothing) * *cLabels + labelsSmoothing / cLabels->sizeAt(1));
} }
NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimensions, true)).transform(transform::Exp); NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimensions, true)).transform(transform::Exp);
softmax /= softmax.reduceAlongDims(reduce::Sum, dimensions, true); softmax /= softmax.reduceAlongDimension(reduce::Sum, dimensions, true);
 // dEdp = softmax * sum_i(lables_i) - labels
-dLdp->assign(softmax * newLabels->reduceAlongDims(reduce::Sum, dimensions, true) - *newLabels);
+dLdp->assign(softmax * newLabels->reduceAlongDimension(reduce::Sum, dimensions, true) - *newLabels);
 // dEdl = -log(softmax)
 dLdl->assign(-softmax.transform(transform::Log)* (1.f - labelsSmoothing));
-NDArray shiftedLogits = *logits - logits->reduceAlongDims(reduce::Max, dimensions, true);
+NDArray shiftedLogits = *logits - logits->reduceAlongDimension(reduce::Max, dimensions, true);
-NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDims(reduce::Sum, dimensions, true).transform(transform::Log);
+NDArray logSumExp = shiftedLogits.transform(transform::Exp).reduceAlongDimension(reduce::Sum, dimensions, true).transform(transform::Log);
-NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDims(reduce::Sum, dimensions);
+NDArray E = (*newLabels * (logSumExp - shiftedLogits)).reduceAlongDimension(reduce::Sum, dimensions);
 // perform weights broadcasting/tile to E if it is necessary
 auto weightsBroad = weights;
@@ -253,12 +253,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
 *dLdl *= *weights;
 }
 else {
-dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad);
+dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdp);
-dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, weightsBroad);
+dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, *weightsBroad, *dLdl);
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign(E);
@@ -289,12 +289,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
 else {
 NDArray temp = *weightsBroad / sum;
-dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp);
+dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp);
-dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp);
+dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl);
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum)).reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 }
 else
 dLdw->assign((E * sum - (E * *weightsBroad).reduceNumber(reduce::Sum)) / (sum*sum));
@@ -326,12 +326,12 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_grad, 3, 3, false, 1, 1) {
 }
 else {
 NDArray temp = *weightsBroad / numOfNonZeroWeights;
-dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp);
+dLdp->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdp);
-dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, &temp);
+dLdl->applyBroadcast(nd4j::broadcast::Multiply, dimensions, temp, *dLdl);
 if(weights != weightsBroad) {
 std::vector<int> axesToReduceAlong = ShapeUtils::evalBroadcastBackwardAxis(weights->getShapeInfo(), weightsBroad->getShapeInfo());
-E.reduceAlongDimension(reduce::Sum, dLdw, axesToReduceAlong, true, false, false);
+E.reduceAlongDimension(reduce::Sum, *dLdw, axesToReduceAlong, true, false, false);
 *dLdw /= numOfNonZeroWeights;
 }
 else
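The change repeated throughout this hunk, and in most of the files below, is that NDArray methods now take auxiliary and output arrays by reference rather than by pointer (weightsBroad becomes *weightsBroad, &temp becomes temp), with the target array named explicitly as the last argument. A minimal sketch of that signature change, using a hypothetical Array type rather than the real NDArray class:

    #include <vector>

    // Hypothetical stand-in for NDArray, used only to illustrate the signature change.
    struct Array {
        std::vector<float> data;
    };

    // Old style: arrays passed as pointers; the callee must dereference and nullability is implicit.
    void multiplyOld(const Array& in, const Array* other, Array* out) {
        for (size_t i = 0; i < in.data.size(); ++i)
            out->data[i] = in.data[i] * other->data[i];
    }

    // New style: arrays passed by reference; non-null by construction, intent is explicit at the call site.
    void multiplyNew(const Array& in, const Array& other, Array& out) {
        for (size_t i = 0; i < in.data.size(); ++i)
            out.data[i] = in.data[i] * other.data[i];
    }

    int main() {
        Array a{{1.f, 2.f}}, b{{3.f, 4.f}}, c{{0.f, 0.f}};
        multiplyOld(a, &b, &c);   // old call shape: &array
        multiplyNew(a, b, c);     // new call shape: array
        return 0;
    }

In the real calls above the output is often the object itself (dLdp receives its own broadcast result), which the explicit reference argument makes visible at the call site.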

@@ -34,38 +34,38 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0) {
 auto output = OUTPUT_VARIABLE(0);
 const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1;
 // input validation
 REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
 REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf());
 std::vector<int> dimension = {classesDim};
-auto maxAlongDim = logits->reduceAlongDims(reduce::Max, {classesDim}, true);
+auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, {classesDim}, true);
 auto logExp = (*logits - maxAlongDim).transform(transform::Exp);
-auto logSoftMax = ( logExp / logExp.reduceAlongDims(reduce::Sum, {classesDim}, true) ).transform(transform::Log);
+auto logSoftMax = ( logExp / logExp.reduceAlongDimension(reduce::Sum, {classesDim}, true) ).transform(transform::Log);
-(-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, output, dimension);
+(-(*labels) * logSoftMax).reduceAlongDimension(reduce::Sum, *output, dimension);
 return Status::OK();
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_TYPES(softmax_cross_entropy_loss_with_logits) {
 getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits) {
 auto logitsShapeInfo = inputShape->at(0);
 auto labelsShapeInfo = inputShape->at(1);
 const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : -1;
 std::vector<int> dimensions = {classesDim};
 // labels and logits must have the same shapes
 REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
 auto outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
@@ -90,46 +90,46 @@ CUSTOM_OP_IMPL(softmax_cross_entropy_loss_with_logits_grad, 2, 2, false, 0, 0) {
 auto dLdl = OUTPUT_VARIABLE(1); // dL/dlabels
 const int classesDim = block.getIArguments()->size() > 0 ? INT_ARG(0) : logits->rankOf()-1;
 // input validation
 REQUIRE_TRUE(labels->isSameShape(logits), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(labels).c_str(), ShapeUtils::shapeAsString(logits).c_str());
 REQUIRE_TRUE(classesDim < logits->rankOf(), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: class dimension must be smaller than rank of logits, but got %i and %i correspondingly !", classesDim, logits->rankOf());
 std::vector<int> dimension = {classesDim};
-NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp);
-softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true);
+NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp);
+softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true);
 // dEdp = softmax * sum_i(labels_i) - labels
-dLdp->assign(softmax * labels->reduceAlongDims(reduce::Sum, dimension, true) - *labels);
+dLdp->assign(softmax * labels->reduceAlongDimension(reduce::Sum, dimension, true) - *labels);
 // dEdl = -log(softmax)
-(-softmax).applyTransform(transform::Log, dLdl);
+(-softmax).applyTransform(transform::Log, *dLdl);
 return Status::OK();
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_TYPES(softmax_cross_entropy_loss_with_logits_grad) {
 getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)->setAllowedOutputTypes({ALL_FLOATS});
 }
 //////////////////////////////////////////////////////////////////////////
 DECLARE_SHAPE_FN(softmax_cross_entropy_loss_with_logits_grad) {
 auto logitsShapeInfo = inputShape->at(0);
 auto labelsShapeInfo = inputShape->at(1);
 // labels and logits must have the same shapes
 REQUIRE_TRUE(shape::shapeEquals(logitsShapeInfo, labelsShapeInfo), 0, "SOFTMAX_CROSS_ENTROPY_LOSS_WITH_LOGITS_GRAD OP: labels and logits arrays must have the same shapes, but got %s and %s correspondingly!", ShapeUtils::shapeAsString(labelsShapeInfo).c_str(), ShapeUtils::shapeAsString(logitsShapeInfo).c_str());
 DataType outType = DataTypeUtils::pickFloatingType(ArrayOptions::dataType(logitsShapeInfo));
 auto dLdpShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(logitsShapeInfo), shape::shapeOf(logitsShapeInfo), shape::rank(logitsShapeInfo)));
 auto dLdlShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outType, shape::order(labelsShapeInfo), shape::shapeOf(labelsShapeInfo), shape::rank(labelsShapeInfo)));
 return SHAPELIST(dLdpShapeInfo, dLdlShapeInfo);
 }
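The forward and backward passes above use the max-shifted softmax and the gradients noted in the comments (dEdp = softmax * sum_i(labels_i) - labels, dEdl = -log(softmax)). A plain standard-C++ sketch of that per-row arithmetic, independent of the NDArray API:

    #include <vector>
    #include <cmath>
    #include <algorithm>
    #include <cstdio>

    int main() {
        std::vector<double> logits = {1.0, 2.0, 3.0};
        std::vector<double> labels = {0.0, 0.0, 1.0};

        const double mx = *std::max_element(logits.begin(), logits.end());
        std::vector<double> softmax(logits.size());
        double denom = 0.0;
        for (size_t i = 0; i < logits.size(); ++i) {
            softmax[i] = std::exp(logits[i] - mx);   // shift by the max for numerical stability
            denom += softmax[i];
        }
        double labelSum = 0.0;
        for (size_t i = 0; i < softmax.size(); ++i) {
            softmax[i] /= denom;
            labelSum += labels[i];
        }
        for (size_t i = 0; i < softmax.size(); ++i) {
            const double dLdp = softmax[i] * labelSum - labels[i];  // dEdp = softmax * sum_i(labels_i) - labels
            const double dLdl = -std::log(softmax[i]);              // dEdl = -log(softmax)
            std::printf("i=%zu  dLdp=%f  dLdl=%f\n", i, dLdp, dLdl);
        }
        return 0;
    }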

@@ -50,9 +50,9 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits, 2, 1, false, 0, 0)
 std::vector<int> dimension = {-1};
-auto maxAlongDim = logits->reduceAlongDims(reduce::Max, dimension, true);
+auto maxAlongDim = logits->reduceAlongDimension(reduce::Max, dimension, true);
 auto logitsExp = (*logits - maxAlongDim).transform(transform::Exp, nullptr);
-auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDims(reduce::Sum, dimension, true) ).transform(transform::Log));
+auto logSoftMax = -(( logitsExp / logitsExp.reduceAlongDimension(reduce::Sum, dimension, true) ).transform(transform::Log));
 helpers::scatterForLoss(block.launchContext(), *labels, logSoftMax, *output, false);
@@ -117,8 +117,8 @@ CUSTOM_OP_IMPL(sparse_softmax_cross_entropy_loss_with_logits_grad, 2, 1, false,
 std::vector<int> dimension = {-1};
-NDArray softmax = (*logits - logits->reduceAlongDims(reduce::Max, dimension, true)).transform(transform::Exp);
+NDArray softmax = (*logits - logits->reduceAlongDimension(reduce::Max, dimension, true)).transform(transform::Exp);
-softmax /= softmax.reduceAlongDims(reduce::Sum, dimension, true);
+softmax /= softmax.reduceAlongDimension(reduce::Sum, dimension, true);
 // dEdp = softmax - 1 (or 0)
 dLdp->assign(softmax);

@@ -229,19 +229,19 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
 // input - mean
 NDArray xMinusMean(input); // empty array with same shape as input
-input->applyBroadcast(nd4j::broadcast::Subtract, axes, mean, &xMinusMean);
+input->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean);
 // stdInv
 NDArray stdInv = *variance + epsilon;
-stdInv.applyTransform(transform::Reciprocal); // 1 / (variance + epsilon)
+stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon)
-stdInv.applyTransform(transform::Sqrt); // 1 / (variance + epsilon)^0.5
+stdInv.applyTransform(transform::Sqrt, stdInv); // 1 / (variance + epsilon)^0.5
 // dvdm (use dLdM as storage for dvdm)
-xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dLdM, excludedAxes, keepUnitiesInShape);
+xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, *dLdM, excludedAxes, keepUnitiesInShape);
 *dLdM *= -Ninv;
 // g_sum
-auto gSum = dLdO->reduceAlongDims(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
+auto gSum = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes, keepUnitiesInShape);
 // dLdB
 if(applyOffset)
@@ -249,11 +249,11 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
 // stdInv * (g - g_sum/N) (use dLdI as storage for this expression)
 gSum *= Ninv;
-dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, &gSum, dLdI);
+dLdO->applyBroadcast(nd4j::broadcast::Subtract, axes, gSum, *dLdI);
-dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, &stdInv);
+dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, stdInv, *dLdI);
 // dLdV <- [g*(x - m)]_sum
-(xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dLdV, excludedAxes, keepUnitiesInShape);
+(xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, *dLdV, excludedAxes, keepUnitiesInShape);
 // dLdG
 *dLdV *= stdInv;
@@ -265,13 +265,13 @@ CUSTOM_OP_IMPL(batchnorm_bp, 4, 3, false, 1, 2) {
 *dLdV *= -Ninv; // -0.5f * (2 / N);
 // dfdv * (dvdm + (x - m)) (use xMinusMean as storage for this expression)
-xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dLdM);
+xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, *dLdM, xMinusMean);
-xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dLdV);
+xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, *dLdV, xMinusMean);
 // dLdI
 *dLdI += xMinusMean;
 if(applyScale)
-dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, gamma);
+dLdI->applyBroadcast(nd4j::broadcast::Multiply, axes, *gamma, *dLdI);
 *dLdM = 0; // put zeros so far
 *dLdV = 0; // put zeros so far
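The two chained transforms above compute stdInv = 1 / sqrt(variance + epsilon) in place, first the reciprocal and then the square root. A small standard-C++ sketch of just that step, with made-up per-channel variances:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<float> variance = {0.25f, 1.0f, 4.0f};  // per-channel variances (made-up values)
        const float epsilon = 1e-5f;

        for (float v : variance) {
            const float reciprocal = 1.0f / (v + epsilon);   // transform::Reciprocal step: 1 / (variance + epsilon)
            const float stdInv = std::sqrt(reciprocal);      // transform::Sqrt step: 1 / (variance + epsilon)^0.5
            std::printf("%f ", stdInv);
        }
        std::printf("\n");
        return 0;
    }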

@@ -240,7 +240,7 @@ CUSTOM_OP_IMPL(conv3dnew_bp, 3, 2, false, 0, 13) {
 if(gradB) {
 if(gradB->rankOf() == 2)
 gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-gradO->reduceAlongDimension(reduce::Sum, gradB, gradOaxesForDot); // sum over bS oD oH oW
+gradO->reduceAlongDimension(reduce::Sum, *gradB, gradOaxesForDot); // sum over bS oD oH oW
 if(gradB != OUTPUT_VARIABLE(2))
 delete gradB;
 }

@@ -234,7 +234,7 @@ CUSTOM_OP_IMPL(deconv2d_bp, 3, 2, false, 0, 9) {
 if(gradB) {
 if(gradB->rankOf() == 2)
 gradB = new NDArray(gradB->reshape(gradB->ordering(), {gradB->lengthOf()}));
-gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3}); // sum over bS, oH, oW
+gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3}); // sum over bS, oH, oW
 if(gradB != OUTPUT_VARIABLE(2))
 delete gradB;
 }

@@ -244,7 +244,7 @@ CUSTOM_OP_IMPL(deconv3d_bp, 3, 2, false, 0, 13) {
 if(gradB) {
 if(gradB->rankOf() == 2)
 gradB = new NDArray(gradB->reshape(gradB->ordering(), {(int)gradB->lengthOf()}));
-gradO->reduceAlongDimension(reduce::Sum, gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
+gradO->reduceAlongDimension(reduce::Sum, *gradB, {0, 2, 3, 4}); // sum over bS, oD, oH, oW
 if(gradB != OUTPUT_VARIABLE(2))
 delete gradB;
 }

@@ -42,35 +42,35 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
 auto batchVar = OUTPUT_VARIABLE(2); // [iD]
 const bool dataFormat = (bool)INT_ARG(0); // 0->NHWC, 1->NCHW
 const bool isTraining = (bool)INT_ARG(1);
 REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());
 int bS = x->sizeAt(0); // batch size
 int iH, iW, iD; // input height, input width, input depth(number of channels)
 if(dataFormat) {
 iD = x->sizeAt(1);
 iH = x->sizeAt(2);
 iW = x->sizeAt(3);
 }
 else {
 iD = x->sizeAt(3);
 iH = x->sizeAt(1);
 iW = x->sizeAt(2);
 }
 REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
 REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());
 NDArray *mean(nullptr), *variance(nullptr);
 if(!isTraining){
 mean = INPUT_VARIABLE(3);
 variance = INPUT_VARIABLE(4);
 REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
 REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
 }
 else {
 //REQUIRE_TRUE(block.width() == 3, 0, "CUSTOM_OP fused_batch_norm: when isTraining=true then number of input arrays must be equal to 3, but got %i instead !", block.width());
 std::vector<Nd4jLong> shape = {iD};
 mean = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
 variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
@@ -78,13 +78,13 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
 // FIXME: double?
 double epsilon;
 if(block.getTArguments()->size() > 0)
 epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;
 else
 epsilon = 0.001;
 const int restSize = x->lengthOf() / iD;
-auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, x->dataType(), block.launchContext());
+auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
 xAffected.assign(x);
 const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
@@ -93,28 +93,28 @@ CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
 const double restSizeAdjust = (double)restSize / restSizeMinusOne;
 if(isTraining) {
-auto sum = xAffected.reduceAlongDims(reduce::Sum, {0});
+auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
 sum *= restSizeInv;
 mean->assign(sum);
 *batchMean = *mean;
 //delete sum;
 }
 else
 *batchMean = 0.;
 xAffected -= *mean;
 if(isTraining) {
 int power = 2;
-xAffected.applyScalar(scalar::Pow, power);
+xAffected.applyScalar(scalar::Pow, power, xAffected);
-auto sum = xAffected.reduceAlongDims(reduce::Sum, {0});
+auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
 sum *= restSizeInv;
 variance->assign(sum);
 *batchVar = (*variance) * restSizeAdjust;
 //delete sum;
 }
 else
 *batchVar = 0.;
 xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale) + (*offset);
 y->assign( xAffected );
@@ -136,13 +136,13 @@ DECLARE_SHAPE_FN(fused_batch_norm) {
 const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
 REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());
 Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);
 COPY_SHAPE(xShapeInfo, outShapeInfo);
 COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
 COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);
 return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
 }
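In training mode the code above forms the batch statistics as mean = sum(x) * restSizeInv, variance = sum((x - mean)^2) * restSizeInv, and reports batchVar scaled by restSizeAdjust = restSize / (restSize - 1). A standard-C++ sketch of those steps for a single channel:

    #include <vector>
    #include <cstdio>

    int main() {
        std::vector<double> x = {1.0, 2.0, 3.0, 4.0};   // values of one channel across the "rest" dimensions
        const double restSize = static_cast<double>(x.size());
        const double restSizeMinusOne = restSize > 1 ? restSize - 1 : 1;

        double mean = 0.0;
        for (double v : x) mean += v;
        mean /= restSize;                                 // sum * restSizeInv

        double variance = 0.0;
        for (double v : x) variance += (v - mean) * (v - mean);
        variance /= restSize;                             // biased (population) variance

        const double batchVar = variance * (restSize / restSizeMinusOne); // restSizeAdjust correction reported as batchVar
        std::printf("mean=%f variance=%f batchVar=%f\n", mean, variance, batchVar);
        return 0;
    }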

@@ -37,7 +37,7 @@ namespace ops {
 CONFIGURABLE_OP_IMPL(log_softmax, 1, 1, true, 0, 0) {
 auto input = INPUT_VARIABLE(0);
 auto output = OUTPUT_VARIABLE(0);
 const int rank = input->rankOf();
 const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
@@ -67,8 +67,8 @@ CONFIGURABLE_OP_IMPL(log_softmax_bp, 2, 1, true, 0, 0) {
 REQUIRE_TRUE(dim < rank, 0, "LOG_SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
 helpers::softmax(block.launchContext(), *input, *gradI, dim);
-gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true) );
+gradI->assign( *gradO - (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true) );
 return Status::OK();
 }

@@ -31,10 +31,10 @@ namespace nd4j {
 REQUIRE_TRUE(w->isMatrix(), 0, "relu_layer: weights argument should be a 2D tensor, but got rank %i instead!", w->rankOf());
 REQUIRE_TRUE(b->isVector(), 0, "relu_layer: biases argument should be a 1D tensor, but got rank %i instead!", b->rankOf());
 REQUIRE_TRUE(b->lengthOf() == w->sizeAt(1), 0, "relu_layer: biases array length should match to columns of weights matrix, however got length = %i and columns = %i!", b->lengthOf(), w->sizeAt(1));
 REQUIRE_TRUE(x->sizeAt(1) == w->sizeAt(0), 0, "relu_layer: number of x columns should match to row number of weights matrix, but got x_columns = %i and weights_rows = %i!",
 x->sizeAt(1), w->sizeAt(0));
 auto output = OUTPUT_VARIABLE(0);
 //T bound = (T)0.f;
 //nd4j_printf("Matrix x(%ix%i), Matrix w(%ix%i), b(1x%i)\n", x->sizeAt(0), x->sizeAt(1), w->sizeAt(0), w->sizeAt(1), b->lengthOf());
@@ -46,7 +46,7 @@ namespace nd4j {
 auto scalar = block.numT() > 0 ? block.getTArguments()->at(0) : 0.0;
 auto xw = result->at(0);
-xw->applyScalar(nd4j::scalar::RELU, scalar, output);
+xw->applyScalar(nd4j::scalar::RELU, scalar, *output);
 return Status::OK();
 }
@@ -55,7 +55,7 @@ namespace nd4j {
 auto inShape = inputShape->at(0);
 auto weightsShape = inputShape->at(1);
 auto outputShape = ShapeUtils::matrixProductShape(inShape, weightsShape, false, false, ArrayOptions::dataType(inShape), block.getWorkspace());
 return SHAPELIST(CONSTANT(outputShape));
 }

@@ -38,7 +38,7 @@ namespace ops {
 CONFIGURABLE_OP_IMPL(softmax, 1, 1, true, 0, 0) {
 auto input = INPUT_VARIABLE(0);
 auto output = OUTPUT_VARIABLE(0);
 const int rank = input->rankOf();
 const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
@@ -59,10 +59,10 @@ CONFIGURABLE_OP_IMPL(softmax_bp, 2, 1, true, 0, 0) {
 const int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
 REQUIRE_TRUE(dim < rank, 0, "SOFTMAX_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
 helpers::softmax(block.launchContext(), *input, *gradI, dim);
-auto sumAlongDim = (*gradI * *gradO).reduceAlongDims(reduce::Sum, {dim}, true);
+auto sumAlongDim = (*gradI * *gradO).reduceAlongDimension(reduce::Sum, {dim}, true);
 gradI->assign(*gradI * (*gradO - sumAlongDim));
 return Status::OK();

@@ -56,7 +56,7 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {
 axes[i] = i;
 // mean as reduction for last dimension set
-auto mean = input->reduceAlongDims(reduce::Mean, axes);
+auto mean = input->reduceAlongDimension(reduce::Mean, axes);
 // this is contrast calculation
 output->assign((*input - mean) * (*factor) + mean);
@@ -104,13 +104,13 @@ CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
 std::vector<int> axes({1}); // dim 1 of pseudoresult
 // mean as reduction for last dimension set over size (dim 1) of result3D
-auto mean = input3D.reduceAlongDims(reduce::Mean, axes);
+auto mean = input3D.reduceAlongDimension(reduce::Mean, axes);
 // result as (x - mean) * factor + mean
 auto temp = input3D.ulike();
-input3D.applyBroadcast(broadcast::Subtract, {0, 2}, &mean, &temp, nullptr);
+input3D.applyBroadcast(broadcast::Subtract, {0, 2}, mean, temp);
-temp.applyScalarArr(scalar::Multiply, factor);
+temp.applyScalarArr(scalar::Multiply, *factor, temp);
-temp.applyBroadcast(broadcast::Add, {0, 2}, &mean, &output3D);
+temp.applyBroadcast(broadcast::Add, {0, 2}, mean, output3D);
 output->assign(output3D);
 if(block.width() == 1)
 delete factor;
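Both adjust_contrast variants above implement output = (input - mean) * factor + mean with a per-channel mean; the second version expresses the same formula through a broadcast subtract, a scalar multiply, and a broadcast add. A per-channel standard-C++ sketch of that arithmetic:

    #include <vector>
    #include <cstdio>

    int main() {
        std::vector<double> channel = {0.2, 0.4, 0.6, 0.8};  // one image channel, flattened (made-up values)
        const double factor = 2.0;                            // contrast factor

        double mean = 0.0;
        for (double v : channel) mean += v;
        mean /= channel.size();                               // per-channel mean

        for (double& v : channel)
            v = (v - mean) * factor + mean;                   // the contrast calculation from the op

        for (double v : channel) std::printf("%f ", v);
        std::printf("\n");
        return 0;
    }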

@@ -44,11 +44,11 @@ namespace nd4j {
 auto axisVector = INPUT_VARIABLE(1);
 helpers::adjustAxis(input->rankOf(), axisVector, axis);
-input->applyIndexReduce(indexreduce::IndexMax, output, axis);
+input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
 } else {
 helpers::adjustAxis(input->rankOf(), axis);
-input->applyIndexReduce(indexreduce::IndexMax, output, axis);
+input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
 }
 STORE_RESULT(output);

@@ -44,11 +44,11 @@ namespace nd4j {
 auto axisVector = INPUT_VARIABLE(1);
 helpers::adjustAxis(input->rankOf(), axisVector, axis);
-input->applyIndexReduce(indexreduce::IndexMin, output, axis);
+input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
 } else {
 helpers::adjustAxis(input->rankOf(), axis);
-input->applyIndexReduce(indexreduce::IndexMin, output, axis);
+input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
 }
 STORE_RESULT(output);

@@ -82,7 +82,7 @@ CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
 gradI->assign(gradO);
-gradO->reduceAlongDimension(nd4j::reduce::Sum, gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
+gradO->reduceAlongDimension(nd4j::reduce::Sum, *gradB, ShapeUtils::evalDimsToExclude(gradO->rankOf(), {channelDim}));
 return ND4J_STATUS_OK;
 }

@@ -45,7 +45,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
 v = i++;
 }
-std::unique_ptr<ResultSet> outputView(output->allTensorsAlongDimension(dims));
+ResultSet outputView = output->allTensorsAlongDimension(dims);
 REQUIRE_TRUE(block.width() > output->sizeAt(0), 0, "embedding_lookup: input list should be greater then %i, but %i given.",
 output->sizeAt(0), block.width()
 );
@@ -53,7 +53,7 @@ CUSTOM_OP_IMPL(embedding_lookup, 2, 1, false, 0, 1) {
 Nd4jLong thisIndex = (*indeces).e<Nd4jLong>(e);
 input = INPUT_VARIABLE(thisIndex); // lookup param
-outputView->at(e)->assign(input);
+outputView.at(e)->assign(input);
 }
 }
 else {
@@ -87,7 +87,7 @@ DECLARE_SHAPE_FN(embedding_lookup) {
 int inRank = shape::rank(inShapeInfo);
 if (inputShape->size() == 2u) {
 int outRank = inRank;
 std::vector<Nd4jLong> shapeInfo(outRank);
 shapeInfo[0] = indecesShapeInfo[1]; // vector - how many elements
@@ -98,14 +98,14 @@ DECLARE_SHAPE_FN(embedding_lookup) {
 return SHAPELIST(outShapeInfo);
 }
 int outRank = inRank + 1;
 std::vector<Nd4jLong> shapeInfo(outRank);
 auto indeces = INPUT_VARIABLE(block.width() - 1);
 shapeInfo[0] = indeces->lengthOf(); // vector - how many elements
 for (int e = 1; e < outRank; e++)
 shapeInfo[e] = shape::sizeAt(inShapeInfo, e);
 auto outShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inShapeInfo), shape::order(inShapeInfo), shapeInfo);
 return SHAPELIST(outShapeInfo);
 }
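The embedding_lookup change above swaps a heap-allocated ResultSet held in std::unique_ptr for a ResultSet returned by value, so access changes from outputView->at(e) to outputView.at(e) and no manual cleanup is needed. A hypothetical sketch of that ownership difference, with TensorList standing in for the real ResultSet:

    #include <memory>
    #include <vector>

    // Hypothetical stand-in for ResultSet: a list of views into a larger array.
    struct TensorList {
        std::vector<int> views;
        int& at(size_t e) { return views.at(e); }
    };

    // Old style: the factory returns a raw pointer that the caller must wrap or remember to delete.
    TensorList* makeViewsOld() { return new TensorList{{1, 2, 3}}; }

    // New style: the factory returns the container by value; lifetime is automatic.
    TensorList makeViewsNew() { return TensorList{{1, 2, 3}}; }

    int main() {
        std::unique_ptr<TensorList> oldViews(makeViewsOld());
        oldViews->at(0) = 10;          // pointer-style access, explicit ownership wrapper

        TensorList newViews = makeViewsNew();
        newViews.at(0) = 10;           // value-style access, no delete required
        return 0;
    }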

@@ -49,8 +49,8 @@ namespace nd4j {
 }
 std::vector<int>& dims = axis;
-input->varianceAlongDimension(variance::SummaryStatsVariance, variances, false, axis);
+input->varianceAlongDimension(variance::SummaryStatsVariance, *variances, false, axis);
-input->reduceAlongDimension(reduce::Mean, means, axis, keepDims);
+input->reduceAlongDimension(reduce::Mean, *means, axis, keepDims);
 return Status::OK();
 }
@@ -74,10 +74,10 @@ namespace nd4j {
 }
 //std::vector<int> dims = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
 const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false;
 auto meanShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace());
 auto varianceShape = ShapeUtils::evalReduceShapeInfo('c', axis, *input, keepDims, false, block.workspace());
 return SHAPELIST(meanShape, varianceShape);
 }
 DECLARE_TYPES(moments) {

@@ -52,31 +52,31 @@ namespace nd4j {
 case 0: {
 REQUIRE_TRUE(dims.size() == 2 || (input->rankOf() == 2 && dims.size() == 0), 0, "Norm: Frobenius is defined for 2D matrices or TADS only");
 // fro
-input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2);
 }
 break;
 case 1: {
 // euclidean
 if ((input->rankOf() == 2 && dims.size() == 0) || dims.size() == 2) {
-input->reduceAlongDimension(reduce::NormFrobenius, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::NormFrobenius, *output, dims, false, output->rankOf() == 2);
 } else {
-input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2);
 }
 }
 break;
 case 2: {
 // 1
-input->reduceAlongDimension(reduce::Norm1, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::Norm1, *output, dims, false, output->rankOf() == 2);
 }
 break;
 case 3: {
 // 2
-input->reduceAlongDimension(reduce::Norm2, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::Norm2, *output, dims, false, output->rankOf() == 2);
 }
 break;
 case 4: {
 // inf-norm
-input->reduceAlongDimension(reduce::NormMax, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::NormMax, *output, dims, false, output->rankOf() == 2);
 }
 break;
 default: {
@@ -84,7 +84,7 @@ namespace nd4j {
 REQUIRE_TRUE(block.getIArguments()->size() > 1, 0, "P-Norm reductions requires 2 TArguments, but only 1 was provided");
 // FIXME: p is required here
 //T p = T_ARG(1);
-input->reduceAlongDimension(reduce::NormP, output, dims, false, output->rankOf() == 2);
+input->reduceAlongDimension(reduce::NormP, *output, dims, false, output->rankOf() == 2);
 }
 }

@@ -40,23 +40,20 @@ namespace nd4j {
 shift.assign(T_ARG(0));
 }
-means->applyScalarArr(scalar::Divide, counts, resMeans, nullptr);
+means->applyScalarArr(scalar::Divide, *counts, *resMeans);
-NDArray* squareMeans = resMeans->dup('c');
+NDArray squareMeans = resMeans->dup('c');
-NDArray* tempVariances = resVariances->dup('c');
+NDArray tempVariances = resVariances->dup('c');
-squareMeans->applyTransform(transform::Square, squareMeans, nullptr);
+squareMeans.applyTransform(transform::Square, squareMeans, nullptr);
-variances->applyScalarArr(scalar::Divide, counts, tempVariances, nullptr);
+variances->applyScalarArr(scalar::Divide, *counts, tempVariances);
-// tempVariances->printIndexedBuffer("varianced divided by count");
+// tempVariances.printIndexedBuffer("varianced divided by count");
-tempVariances->applyPairwiseTransform(pairwise::Subtract, squareMeans, resVariances, nullptr);
+tempVariances.applyPairwiseTransform(pairwise::Subtract, squareMeans, *resVariances);
 if (shift.e<double>(0) != 0) {
-resMeans->applyScalarArr(scalar::Add, &shift, resMeans, nullptr);
+resMeans->applyScalarArr(scalar::Add, shift, *resMeans);
 }
-delete squareMeans;
-delete tempVariances;
 return Status::OK();
 }
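normalize_moments above divides the accumulated means and variances by counts, subtracts the squared mean, and adds the optional shift back to the mean. A scalar standard-C++ sketch of that arithmetic, assuming the inputs hold accumulated sums:

    #include <cstdio>

    int main() {
        const double counts = 4.0;        // number of accumulated samples
        const double sumX   = 10.0;       // accumulated sum of x        ("means" input, made-up value)
        const double sumX2  = 30.0;       // accumulated sum of x^2      ("variances" input, made-up value)
        const double shift  = 0.0;        // optional shift (T_ARG(0))

        const double mean     = sumX / counts;                 // means / counts
        const double variance = sumX2 / counts - mean * mean;  // variances / counts - mean^2
        const double outMean  = mean + shift;                  // shift added back when non-zero

        std::printf("mean=%f variance=%f\n", outMean, variance);
        return 0;
    }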

@@ -47,7 +47,7 @@ CUSTOM_OP_IMPL(reduce_mean, 1, 1, false, 0, 0) {
 for(const auto& item : dimensions)
 REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_MEAN OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-input->reduceAlongDimension(reduce::Mean, output, dimensions, keepDims);
+input->reduceAlongDimension(reduce::Mean, *output, dimensions, keepDims);
 return Status::OK();
 }

@@ -55,7 +55,7 @@ CUSTOM_OP_IMPL(reduce_stdev, 1, 1, false, 0, 0) {
 for(const auto& item : dimensions)
 REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_STDEV OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, output, biasCorrected, dimensions);
+input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, *output, biasCorrected, dimensions);
 return Status::OK();
 }
@@ -130,10 +130,10 @@ CUSTOM_OP_IMPL(reduce_stdev_bp, 2, 1, false, 0, 0) {
 const Nd4jLong N = input->lengthOf() / gradO->lengthOf();
 const Nd4jLong NminusOne = biasCorrected ? N - 1 : N;
-auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true);
+auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true);
 NDArray variance(mean.getShapeInfo(), true, block.launchContext()); // create empty array with shape matching shape of mean array
-input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, &variance, biasCorrected, dimensions);
+input->varianceAlongDimension(variance::SummaryStatsStandardDeviation, variance, biasCorrected, dimensions);
 gradI->assign( (*input - mean) / (variance * NminusOne)); // automatic broadcasting happens here

@@ -54,8 +54,8 @@ CUSTOM_OP_IMPL(reduce_variance, 1, 1, false, 0, 0) {
 for(const auto& item : dimensions)
 REQUIRE_TRUE(item >= -input->rankOf() && item < input->rankOf(), 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , input->rankOf(), input->rankOf(), item);
-input->varianceAlongDimension(variance::SummaryStatsVariance, output, biasCorrected, dimensions);
+input->varianceAlongDimension(variance::SummaryStatsVariance, *output, biasCorrected, dimensions);
 return Status::OK();
 }
@@ -77,7 +77,7 @@ DECLARE_SHAPE_FN(reduce_variance) {
 }
 REQUIRE_TRUE(dimensions.size() <= INPUT_VARIABLE(0)->rankOf(), 0, "REDUCE_VARIANCE OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size());
 for(const auto& item : dimensions)
 REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item);
@@ -128,9 +128,9 @@ CUSTOM_OP_IMPL(reduce_variance_bp, 2, 1, false, 0, 0) {
 const Nd4jLong NminusOne = biasCorrected ? N - 1 : N;
 const double factor1 = 2.0 / NminusOne;
 const double factor2 = 2.0 / (N * NminusOne);
-auto mean = input->reduceAlongDims(reduce::Mean, dimensions, true);
+auto mean = input->reduceAlongDimension(reduce::Mean, dimensions, true);
 gradI->assign( (*input - mean) * (2.0f / NminusOne)); // automatic broadcasting happens here
 if(!keepDims) {
@@ -153,13 +153,13 @@ DECLARE_SHAPE_FN(reduce_variance_bp) {
 }
 REQUIRE_TRUE(dimensions.size() <= rank, 0, "REDUCE_VARIANCE_BP OP: the number of dimensions to reduce along must be <= input array rank, but got %i instead" , dimensions.size());
 for(const auto& item : dimensions)
 REQUIRE_TRUE(item >= -inputShape->at(0)[0] && item < inputShape->at(0)[0], 0, "REDUCE_VARIANCE_BP OP: the input dimension to reduce along must be in range [-%i, %i), but got %i instead !" , inputShape->at(0)[0], inputShape->at(0)[0], item);
 Nd4jLong* gradIshapeInfo(nullptr);
 COPY_SHAPE(in, gradIshapeInfo);
 return SHAPELIST(CONSTANT(gradIshapeInfo));
 }

@@ -45,9 +45,9 @@ namespace ops {
 //void* whereMax = (void*)();
 auto internal = (*input);
 internal -= maxVals;
-internal.applyTransform(transform::Exp, nullptr, nullptr);
+internal.applyTransform(transform::Exp, internal);
-internal.reduceAlongDimension(reduce::Sum, output, axes, keepDims, false); //, (void*)&maxVals);
+internal.reduceAlongDimension(reduce::Sum, *output, axes, keepDims, false); //, (void*)&maxVals);
-output->applyTransform(transform::Log, nullptr, nullptr);
+output->applyTransform(transform::Log, *output);
 (*output) += maxVals;
 return ND4J_STATUS_OK;
 }
@@ -56,7 +56,7 @@ namespace ops {
 -> setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
 -> setAllowedOutputTypes({ALL_FLOATS});
 }
 DECLARE_SHAPE_FN(reduce_logsumexp) {
 const bool keepDims = block.getTArguments()->size() > 0 ? (bool)T_ARG(0) : false;
 auto input = INPUT_VARIABLE(0);
@@ -74,6 +74,6 @@ namespace ops {
 return SHAPELIST(outShapeInfo);
 }
 #endif
 }
 }
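reduce_logsumexp above uses the usual stabilized form: subtract the per-axis maximum, exponentiate, sum, take the log, and add the maximum back. A standard-C++ sketch for a single vector:

    #include <vector>
    #include <cmath>
    #include <algorithm>
    #include <cstdio>

    int main() {
        std::vector<double> x = {1.0, 2.0, 3.0};

        const double mx = *std::max_element(x.begin(), x.end());
        double sum = 0.0;
        for (double v : x)
            sum += std::exp(v - mx);                  // internal -= maxVals; then exp
        const double logSumExp = std::log(sum) + mx;  // log of the summed exponentials, max added back

        std::printf("logsumexp=%f\n", logSumExp);
        return 0;
    }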

@@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_max, 1, 1, false, 0, 0) {
 else if (block.getTArguments()->size() > 0)
 keepDims = (bool)T_ARG(0);
-input->reduceAlongDimension(reduce::Max, output, dimensions, keepDims);
+input->reduceAlongDimension(reduce::Max, *output, dimensions, keepDims);
 return Status::OK();
 }
@@ -122,8 +122,7 @@ CUSTOM_OP_IMPL(reduce_max_bp, 2, 1, false, 0, 0) {
 else {
 auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMax, dimensions);
-helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation
+helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation
-delete indicesArr;
 }
 return Status::OK();
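reduce_max_bp routes the incoming gradient to the position of each slice's maximum, which applyIndexReduce(IndexMax, ...) locates (and now returns by value, hence the removed delete). A standard-C++ sketch of the same routing for a 2D array reduced along its last axis; reduce_min_bp below is identical with the minimum:

    #include <vector>
    #include <cstdio>

    int main() {
        // 2x3 input reduced along the last axis; gradO holds one gradient value per row.
        const std::vector<std::vector<double>> input = {{1.0, 5.0, 2.0}, {7.0, 0.0, 3.0}};
        const std::vector<double> gradO = {0.1, 0.2};

        std::vector<std::vector<double>> gradI(input.size(), std::vector<double>(input[0].size(), 0.0));
        for (size_t r = 0; r < input.size(); ++r) {
            size_t argMax = 0;                              // index-reduce: position of the row maximum
            for (size_t c = 1; c < input[r].size(); ++c)
                if (input[r][c] > input[r][argMax]) argMax = c;
            gradI[r][argMax] = gradO[r];                    // scatter the gradient to that position only
        }

        for (const auto& row : gradI) {
            for (double v : row) std::printf("%f ", v);
            std::printf("\n");
        }
        return 0;
    }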

@@ -52,7 +52,7 @@ CUSTOM_OP_IMPL(reduce_min, 1, 1, false, 0, 0) {
 else if (block.getTArguments()->size() > 0)
 keepDims = (bool)T_ARG(0);
-input->reduceAlongDimension(reduce::Min, output, dimensions, keepDims);
+input->reduceAlongDimension(reduce::Min, *output, dimensions, keepDims);
 return Status::OK();
 }
@@ -89,7 +89,7 @@ DECLARE_TYPES(reduce_min) {
 }
 #endif
 #if NOT_EXCLUDED(OP_reduce_min_bp)
@@ -125,8 +125,7 @@ CUSTOM_OP_IMPL(reduce_min_bp, 2, 1, false, 0, 0) {
 else {
 auto indicesArr = input->applyIndexReduce(nd4j::indexreduce::IndexMin, dimensions);
-helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, *indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation
+helpers::scatterSimple(block.launchContext(), 6, *gradI, *gradO, indicesArr, ShapeUtils::evalDimsToExclude(gradI->rankOf(), dimensions)); // 6 corresponds to copy operation
-delete indicesArr;
 }
 return Status::OK();
