[WIP] bunch of improvements (#257)

* - profiling bias_add op
- add some documentation

Signed-off-by: Yurii <yurii@skymind.io>

* - minor change

Signed-off-by: Yurii <yurii@skymind.io>

* - provide addBias cuda kernel

Signed-off-by: Yurii <yurii@skymind.io>

* - improve shape::getIndexOffset and change its signature

Signed-off-by: Yurii <yurii@skymind.io>

* - same as previous

Signed-off-by: Yurii <yurii@skymind.io>

* - improve and change signatures of the shape:: helpers that calculate offsets for array elements

Signed-off-by: Yurii <yurii@skymind.io>

* - minor changes in flatten

Signed-off-by: Yurii <shyrma@skymind.io>

* - add function shape::getIndexOffsetOrdered

Signed-off-by: Yurii <shyrma@skymind.io>

* - correct shape::getIndexOffsetOrdered()

Signed-off-by: Yurii <shyrma@skymind.io>

* - move getIndexOffsetOrdered to flatten.h header in order to isolate this function

Signed-off-by: Yurii <shyrma@skymind.io>
master
raver119 2019-09-11 20:12:09 +03:00 committed by GitHub
parent 3e73e9b56e
commit 589401477d
168 changed files with 2428 additions and 2384 deletions
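
The change that repeats across all 168 files is a slimmed-down offset API in the shape namespace: helpers that used to take rank, shape, strides, and the array length as separate arguments now take the single packed shapeInfo buffer, and linear indices are always interpreted in c order. A before/after sketch of the typical call site (names taken from the hunks below):

    // before: rank/shape/strides/length passed piecemeal
    auto off  = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
    auto off2 = shape::getIndexOffset(i, shapeInfo, arrLen);
    shape::index2coords(rank, shapeOf(), i, arrLen, coords);

    // after: one shapeInfo argument carries rank, shape and strides; no length needed
    auto off  = shape::getOffset(getShapeInfo(), coords);
    auto off2 = shape::getIndexOffset(i, shapeInfo);
    shape::index2coords(i, shapeInfo, coords);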

View File

@@ -1770,7 +1770,7 @@ NDArray NDArray::operator()(const Nd4jLong i) const {
     } else {
         Nd4jLong idx[MAX_RANK];
         shape::ind2subC(rankOf(), shapeOf(), i, idx);
-        auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
+        auto xOffset = shape::getOffset(getShapeInfo(), idx);
         auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
         NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1801,7 +1801,7 @@ NDArray& NDArray::operator()(const Nd4jLong i) {
     } else {
         Nd4jLong idx[MAX_RANK];
         shape::ind2subC(rankOf(), shapeOf(), i, idx);
-        auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
+        auto xOffset = shape::getOffset(getShapeInfo(), idx);
         auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
         NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1818,7 +1818,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j) const {
         throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
     Nd4jLong coords[2] = {i, j};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     // TODO: do we really want a view here?
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
@@ -1834,7 +1834,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j) {
         throw std::invalid_argument("NDArray::operator(i,j): one of input indexes is out of array length or rank!=2 !");
     Nd4jLong coords[2] = {i, j};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1853,7 +1853,7 @@ NDArray NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k
         throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
     Nd4jLong coords[3] = {i, j, k};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1870,7 +1870,7 @@ NDArray& NDArray::operator()(const Nd4jLong i, const Nd4jLong j, const Nd4jLong
         throw std::invalid_argument("NDArray::operator(i,j,k): one of input indexes is out of array length or rank!=3 !");
     Nd4jLong coords[3] = {i, j, k};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1886,7 +1886,7 @@ NDArray NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong v
         throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
     Nd4jLong coords[4] = {t, u, v, w};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1900,7 +1900,7 @@ NDArray& NDArray::operator()(const Nd4jLong t, const Nd4jLong u, const Nd4jLong
         throw std::invalid_argument("NDArray::operator(t,u,v,w): one of input indexes is out of array length or rank!=4 !");
     Nd4jLong coords[4] = {t, u, v, w};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     // FIXME
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
@@ -1916,7 +1916,7 @@ NDArray NDArray::operator()(const Nd4jLong* idx) const {
         if (idx[i] >= sizeAt(i))
             throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), idx);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -1931,7 +1931,7 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
         if (idx[i] >= sizeAt(i))
             throw std::invalid_argument("NDArray::operator(const Nd4jLong* idx): input index is out of dimension length !");
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), idx, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), idx);
     auto cast = reinterpret_cast<int8_t *>(_buffer) + (xOffset * this->sizeOfT());
     NDArray result(cast, nd4j::ShapeBuilders::createScalarShapeInfo(this->dataType(), this->getWorkspace()));
@@ -2067,7 +2067,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j) {
         syncToHost();
     Nd4jLong coords[2] = {i, j};
-    auto offset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto offset = shape::getOffset(getShapeInfo(), coords);
     tickWriteHost();
     return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
 }
@@ -2084,7 +2084,7 @@ T& NDArray::t(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) {
         syncToHost();
     Nd4jLong coords[3] = {i, j, k};
-    auto offset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto offset = shape::getOffset(getShapeInfo(), coords);
     tickWriteHost();
     return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
 }
@@ -2118,7 +2118,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
         syncToHost();
     Nd4jLong coords[2] = {i, j};
-    auto offset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto offset = shape::getOffset(getShapeInfo(), coords);
     tickReadHost();
     return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
 }
@@ -2135,7 +2135,7 @@ T NDArray::t(const Nd4jLong i, const Nd4jLong j) const {
         syncToHost();
     Nd4jLong coords[3] = {i, j, k};
-    auto offset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto offset = shape::getOffset(getShapeInfo(), coords);
     tickReadHost();
     return *(reinterpret_cast<T*>(bufferWithOffset(offset)));
 }
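
All of the element accessors above share one pattern: compute an element offset from coordinates, then scale by the element size to get a byte address for the scalar view. A minimal standalone sketch of that arithmetic (a local re-implementation for illustration; getOffset here mirrors the new shape::getOffset shown later in this commit, and the shapeInfo layout [rank, shape..., strides..., type, ews, order] is assumed):

    #include <cstdint>
    #include <cstdio>

    using Nd4jLong = long long;

    // mirrors the new shape::getOffset: walk shapeInfo once, skip size-1 dims
    static Nd4jLong getOffset(const Nd4jLong* shapeInfo, const Nd4jLong* coords) {
        Nd4jLong offset = 0;
        for (Nd4jLong i = 1; i <= shapeInfo[0]; ++i)
            if (shapeInfo[i] != 1)
                offset += coords[i - 1] * shapeInfo[shapeInfo[0] + i];
        return offset;
    }

    int main() {
        float buffer[12];
        for (int k = 0; k < 12; ++k) buffer[k] = 100.0f + k;

        const Nd4jLong shapeInfo[] = {2, 3, 4, 4, 1, 0, 1, 99}; // 3x4, c order
        const Nd4jLong coords[]    = {1, 2};

        const Nd4jLong xOffset = getOffset(shapeInfo, coords);  // 1*4 + 2*1 = 6
        auto cast = reinterpret_cast<int8_t*>(buffer) + xOffset * sizeof(float);
        printf("%.1f\n", *reinterpret_cast<float*>(cast));      // 106.0
        return 0;
    }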

View File

@@ -808,7 +808,7 @@ void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *va
     auto t = reinterpret_cast<T *>(buffer);
     const auto y = *(reinterpret_cast<const Y *>(value));
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), indices, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), indices);
     t[xOffset] = static_cast<T>(y);
 }
 BUILD_DOUBLE_TEMPLATE(template void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES);
@@ -2462,14 +2462,13 @@ double NDArray::getTrace() const {
     int rank = rankOf();
     auto shape = shapeOf();
-    auto strides = stridesOf();
     int minDim = 100000000;
     Nd4jLong indices[MAX_RANK];
     for(int j = 0; j < rank; ++j)
         indices[j] = 1;
-    auto offset = shape::getOffset(0, shape, strides, indices, rank);
+    auto offset = shape::getOffset(getShapeInfo(), indices);
     for(int i = 0; i < rank; ++i)
         if(minDim > shape[i])
@@ -3472,7 +3471,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j) const {
         throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !");
     const Nd4jLong coords[2] = {i, j};
-    const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    const auto xOffset = shape::getOffset(getShapeInfo(), coords);
     NDArray::preparePrimaryUse({}, {this});
     NDArray::registerPrimaryUse({}, {this});
@@ -3492,7 +3491,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const {
         throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !");
     const Nd4jLong coords[3] = {i, j, k};
-    const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    const auto xOffset = shape::getOffset(getShapeInfo(), coords);
     NDArray::preparePrimaryUse({}, {this});
     NDArray::registerPrimaryUse({}, {this});
@@ -3512,7 +3511,7 @@ T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLon
         throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !");
     const Nd4jLong coords[4] = {i, j, k, l};
-    const auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    const auto xOffset = shape::getOffset(getShapeInfo(), coords);
     NDArray::preparePrimaryUse({}, {this});
     NDArray::registerPrimaryUse({}, {this});
@@ -4095,7 +4094,7 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
     void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
     Nd4jLong coords[2] = {i, j};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     NDArray::preparePrimaryUse({this}, {}, true);
     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
@@ -4127,7 +4126,7 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T va
     void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
     Nd4jLong coords[3] = {i, j, k};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
     NDArray::registerPrimaryUse({this}, {});
 }
@@ -4154,7 +4153,7 @@ void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4j
     void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
     Nd4jLong coords[4] = {i, j, k, l};
-    auto xOffset = shape::getOffset(0, shapeOf(), stridesOf(), coords, rankOf());
+    auto xOffset = shape::getOffset(getShapeInfo(), coords);
     NDArray::preparePrimaryUse({this}, {}, true);
     BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->getBuffer(), xOffset, p), LIBND4J_TYPES);
@@ -4409,7 +4408,7 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
     if (i >= lengthOf())
         throw std::invalid_argument("NDArray::getOffset: input index is out of array length !");
-    return shape::getIndexOffset(i, _shapeInfo, lengthOf());
+    return shape::getIndexOffset(i, _shapeInfo);
 }
 NDArray NDArray::like() {
@@ -4455,7 +4454,7 @@ NDArray* NDArray::diagonal(const char type) const {
         indices[i] = 1;
     }
-    auto step = shape::getOffset(0, shapeOf(), stridesOf(), indices, rank);
+    auto step = shape::getOffset(getShapeInfo(), indices);
     if(type == 'c') {
         outShapeInfo[1] = diagSize;

View File

@@ -103,8 +103,8 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
     PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords))
     for (Nd4jLong i = 0; i < zLen; ++i) {
-        shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data());
-        const auto zOffset = shape::getOffset(0, target->shapeOf(), target->stridesOf(), coords.data(), zRank);
+        shape::index2coords(i, target->getShapeInfo(), coords.data());
+        const auto zOffset = shape::getOffset(target->getShapeInfo(), coords.data());
         // if( (row + upper < col) || (row + lower > col) )
         if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
@@ -112,7 +112,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char
         else if(this != target) { // when this and target are different arrays
             if(xRank != zRank)
                 coords[0] = coords[1];
-            const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(0, shapeOf(), stridesOf(), coords.data(), xRank);
+            const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(getShapeInfo(), coords.data());
             z[zOffset] = x[xOffset];
         }
     }
@@ -128,13 +128,12 @@ void NDArray::setIdentity() {
     int rank = rankOf();
     auto shape = shapeOf();
-    auto strides = stridesOf();
     int minDim = MAX_INT;
     Nd4jLong indices[MAX_RANK];
     for(int j = 0; j < rank; ++j)
         indices[j] = 1;
-    Nd4jLong offset = shape::getOffset(0, shape, strides, indices, rank);
+    Nd4jLong offset = shape::getOffset(getShapeInfo(), indices);
     for(int i = 0; i < rank; ++i)
         if(minDim > shape[i])
@@ -380,9 +379,9 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
     PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords))
     for (Nd4jLong i = 0; i < zLen; ++i) {
-        shape::index2coords(rank, output.shapeOf(), i, zLen, coords.data());
-        const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), coords.data(), rank);
+        shape::index2coords(i, output.getShapeInfo(), coords.data());
+        const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
         if(repSize > 1) {
             for (uint j = 0; j < repSize; ++j) {
@@ -396,7 +395,7 @@ static void repeat_(const NDArray& input, NDArray& output, const std::vector<int
         else
             coords[axis] /= repeats[0];
-        z[zOffset] = x[shape::getOffset(0, input.shapeOf(), input.stridesOf(), coords.data(), rank)];
+        z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
     }
 }

View File

@@ -1385,8 +1385,8 @@ void pullRowsGeneric(void *vx,
     }
     else {
         for (int i = 0; i < tadLength; i++) {
-            auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo, tadLength);
-            auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo, tadLength);
+            auto xOffset = xTadOffsetForBlock + shape::getIndexOffset(i, tadShapeInfo);
+            auto zOffset = zTadOffsetForBlock + shape::getIndexOffset(i, zTadShapeInfo);
             hZ[zOffset] = hX[xOffset];
         }
     }
@@ -1450,7 +1450,7 @@ void tearGeneric(void *vx,
         else {
             for (Nd4jLong j = 0; j < tadLength; j++)
-                hZ[shape::getIndexOffset(j, hZShapeInfo, tadLength)] = s[shape::getIndexOffset(j, tadShapeInfo, tadLength)];
+                hZ[shape::getIndexOffset(j, hZShapeInfo)] = s[shape::getIndexOffset(j, tadShapeInfo)];
         }
     }
 }
@@ -1597,7 +1597,7 @@ void shuffleGeneric(void **hX, Nd4jLong **hXShapeInfo, void **dz, Nd4jLong **hZS
     }
     } else {
         for (Nd4jLong i = 0; i < tadLength; i++) {
-            auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f], tadLength);
+            auto offset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
             nd4j::math::nd4j_swap<T>(hX[offset + oldOffset], hX[offset + newOffset]);
         }
     }

View File

@@ -106,8 +106,8 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha
     for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
-        shape::index2coords(zRank, shape::shapeOf(const_cast<Nd4jLong*>(zShapeInfo)), i, zLen, coords);
-        const auto zOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(zShapeInfo)), shape::stride(const_cast<Nd4jLong*>(zShapeInfo)), coords, zRank);
+        shape::index2coords(i, zShapeInfo, coords);
+        const auto zOffset = shape::getOffset(zShapeInfo, coords);
         // if( (row + upper < col) || (row + lower > col) )
         if((coords[zRank - 2] + upper < coords[zRank - 1]) || (coords[zRank - 2] + lower > coords[zRank - 1]))
@@ -115,7 +115,7 @@ __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xSha
         else if(vx != vz) { // when x and z are different arrays
             if(xRank != zRank)
                 coords[0] = coords[1];
-            const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo)), shape::stride(const_cast<Nd4jLong*>(xShapeInfo)), coords, xRank);
+            const auto xOffset = areSameOffsets ? zOffset : shape::getOffset(xShapeInfo, coords);
             z[zOffset] = x[xOffset];
         }
     }
@@ -177,8 +177,8 @@ __global__ static void identityMatrixCuda(void* vx, const Nd4jLong* xShapeInfo,
     for (Nd4jLong i = tid; i < len; i += totalThreads) {
-        shape::index2coords(rank, shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo)), i, len, coords);
-        const auto offset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo)), shape::stride(const_cast<Nd4jLong*>(xShapeInfo)), coords, rank);
+        shape::index2coords(i, xShapeInfo, coords);
+        const auto offset = shape::getOffset(xShapeInfo, coords);
         if(coords[rank - 2] == coords[rank - 1]) // row == col -> on diagonal
             x[offset] = val;
@@ -424,9 +424,9 @@ __global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo,
     for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
-        shape::index2coords(rank, zShapeInfo + 1, i, zLen, coords);
-        const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+        shape::index2coords(i, zShapeInfo, coords);
+        const auto zOffset = shape::getOffset(zShapeInfo, coords);
         if(repSize > 1) {
             for (uint j = 0; j < repSize; ++j) {
@@ -440,7 +440,7 @@ __global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo,
         else
             coords[axis] /= repeats[0];
-        z[zOffset] = x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
+        z[zOffset] = x[shape::getOffset(xShapeInfo, coords)];
     }
 }
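
Note what disappears in these CUDA hunks: the old calls had to slice the shape (zShapeInfo + 1) and strides (zShapeInfo + rank + 1) out of the buffer, or const_cast through shape::shapeOf and shape::stride, inside every kernel. The new index2coords takes the whole buffer instead. A sketch of a decomposition consistent with the new API (a local re-implementation under the usual shapeInfo layout assumption, not the libnd4j source itself):

    using Nd4jLong = long long;

    __host__ __device__ inline void index2coordsSketch(Nd4jLong index, const Nd4jLong* shapeInfo, Nd4jLong* coords) {
        for (Nd4jLong i = shapeInfo[0]; i > 1; --i) {  // peel dims right-to-left (c order)
            coords[i - 1] = index % shapeInfo[i];
            index /= shapeInfo[i];
        }
        coords[0] = index;                             // the remainder is the first coordinate
    }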

View File

@@ -23,8 +23,8 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
-static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
-    return shape::getIndexOffset(index, shapeInfo, length);
+static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+    return shape::getIndexOffset(index, shapeInfo);
 }
 static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
@@ -103,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
         z[e * zEws] = lambda(x[e * xEws]);
     } else {
         for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
-            auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
-            auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
+            auto xOffset = __getIndexOffset(e, xShapeInfo);
+            auto zOffset = __getIndexOffset(e, zShapeInfo);
             z[zOffset] = lambda(x[xOffset]);
         }
@@ -132,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
         z[e * zEws] = lambda(e, x[e * xEws]);
     } else {
         for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
-            auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
-            auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
+            auto xOffset = __getIndexOffset(e, xShapeInfo);
+            auto zOffset = __getIndexOffset(e, zShapeInfo);
             z[zOffset] = lambda(e, x[xOffset]);
         }
@@ -164,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
         z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
     } else {
         for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
-            auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
-            auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
-            auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
+            auto xOffset = __getIndexOffset(e, xShapeInfo);
+            auto yOffset = __getIndexOffset(e, yShapeInfo);
+            auto zOffset = __getIndexOffset(e, zShapeInfo);
             z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
         }
@@ -197,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
         z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
     } else {
         for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
-            auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
-            auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
-            auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
+            auto xOffset = __getIndexOffset(e, xShapeInfo);
+            auto yOffset = __getIndexOffset(e, yShapeInfo);
+            auto zOffset = __getIndexOffset(e, zShapeInfo);
             z[zOffset] = lambda(x[xOffset], y[yOffset]);
         }
@@ -233,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
         z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
     } else {
         for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
-            auto wOffset = __getIndexOffset(e, wShapeInfo, zLength);
-            auto xOffset = __getIndexOffset(e, xShapeInfo, zLength);
-            auto yOffset = __getIndexOffset(e, yShapeInfo, zLength);
-            auto zOffset = __getIndexOffset(e, zShapeInfo, zLength);
+            auto wOffset = __getIndexOffset(e, wShapeInfo);
+            auto xOffset = __getIndexOffset(e, xShapeInfo);
+            auto yOffset = __getIndexOffset(e, yShapeInfo);
+            auto zOffset = __getIndexOffset(e, zShapeInfo);
             z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
         }
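
All five lambda kernels route the shape math through the __noinline__ __getIndexOffset wrapper declared at the top of this file; a plausible reading is code-size control in these heavily templated kernels, though the diff itself does not state the motivation. The per-element pattern they share looks like this (a compressed sketch with local stand-in names, not one of the actual kernel bodies):

    template <typename T, typename Lambda>
    static _CUDA_G void lambdaKernelSketch(const T* x, Nd4jLong* xShapeInfo,
                                           T* z, Nd4jLong* zShapeInfo,
                                           Nd4jLong zLength, Lambda lambda) {
        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        for (Nd4jLong e = tid; e < zLength; e += blockDim.x * gridDim.x) {
            auto xOffset = __getIndexOffset(e, xShapeInfo);  // length argument is gone
            auto zOffset = __getIndexOffset(e, zShapeInfo);
            z[zOffset] = lambda(x[xOffset]);
        }
    }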

View File

@@ -3228,8 +3228,8 @@ __global__ static void scatterUpdateCuda(const int opCode, const int numOfSubArr
     for (Nd4jLong i = threadIdx.x; i < arrLenX; i += blockDim.x) {
-        const auto xOffset = shape::getIndexOffset(i, xShapeInfo, arrLenX);
-        const auto yOffset = shape::getIndexOffset(i, yShapeInfo, arrLenY);
+        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+        const auto yOffset = shape::getIndexOffset(i, yShapeInfo);
         switch (opCode) {
             case 0:

View File

@@ -246,9 +246,9 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto lenPerThread = static_cast<uint>(threadsInfo.getItersPerThread(threadNum));
     PRAGMA_OMP_SIMD
     for (uint i = 0; i < lenPerThread; i++) {
-        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
-        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, len, canCastY);
-        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
+        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
         z[zOffset] = op(x[xOffset], y[yOffset], extraParams);
     }
 }
@@ -452,7 +452,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     for (uint j = 0; j < tadLen; j++)
         start = OpType::update(start, OpType::op(tad[j * tadEws], extraParams), extraParams);
-    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
+    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
     z[zOffset] = OpType::postProcess(start, tadLen, extraParams);
 }
 }
@@ -469,7 +469,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto start = OpType::startingValue(tad);
     for (uint j = 0; j < tadLen; j++) {
-        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, tadLen, canCastTad);
+        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
         start = OpType::update(start, OpType::op(tad[tadOffset], extraParams), extraParams);
     }
@@ -491,11 +491,11 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     // auto start = OpType::startingValue(tad);
     // for (uint j = 0; j < tadLen; j++) {
-    //     auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, tadLen, canCastTad);
+    //     auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
     //     start = OpType::update(start, OpType::op(tad[tadOffset], extraParams), extraParams);
     // }
-    // auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
+    // auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
     // z[zOffset] = OpType::postProcess(start, tadLen, extraParams);
     // }
     // }
@@ -517,7 +517,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     for (uint j = 0; j < tadLen; j++)
         start = OpType::update(start, OpType::op(tad[innertadOffsets[j]], extraParams), extraParams);
-    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
+    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
     z[zOffset] = OpType::postProcess(start, tadLen, extraParams);
 }
@@ -658,13 +658,13 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     PRAGMA_OMP_SIMD
     for (uint i = 0; i < lenPerThread; i++) {
-        const auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, castXShapeInfo, len, canCastX);
+        const auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, castXShapeInfo, canCastX);
         zi[i * zEws] = OpType::op(x[xOffset], extraParams);
     }
 } else {
     PRAGMA_OMP_SIMD
     for (uint i = 0; i < lenPerThread; i++) {
-        const auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, castXShapeInfo, len, canCastX);
+        const auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, castXShapeInfo, canCastX);
         zi[i] = OpType::op(x[xOffset], extraParams);
     }
 }
@@ -782,8 +782,8 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     PRAGMA_OMP_SIMD
     for (uint i = 0; i < lenPerThread; i++) {
-        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
-        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
+        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
         z[zOffset] = OpType::op(x[xOffset], extraParams);
     }
 }
@@ -1123,7 +1123,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto start = OpType::startingValue(xTad);
     for (uint j = 0; j < tadLen; ++j) {
-        const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, tadLen, canCastXTad);
+        const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad);
         start = OpType::update(start, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams);
     }
@@ -1147,8 +1147,8 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto start = OpType::startingValue(xTad);
     for (uint j = 0; j < tadLen; ++j) {
-        const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, tadLen, canCastXTad);
-        const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, tadLen, canCastYTad);
+        const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad);
+        const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad);
         start = OpType::update(start, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams);
     }
@@ -1423,7 +1423,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto start = startVal;
     for (uint j = 0; j < tadLen; ++j) {
-        const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, tadLen, canCastXTad);
+        const auto tadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad);
         start = OpType::update(start, OpType::op(xTad[tadOffset], yTad[tadOffset], extraParams), extraParams);
     }
     z[zInd * zEws] = OpType::postProcess(start, tadLen, extraParams);
@@ -1449,8 +1449,8 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo,
     auto start = startVal;
     for (uint j = 0; j < tadLen; ++j) {
-        const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, tadLen, canCastXTad);
-        const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, tadLen, canCastYTad);
+        const auto xTadOffset = shape::indexOffset(j, xTadShapeInfo, castXTadShapeInfo, canCastXTad);
+        const auto yTadOffset = shape::indexOffset(j, yTadShapeInfo, castYTadShapeInfo, canCastYTad);
         start = OpType::update(start, OpType::op(xTad[xTadOffset], yTad[yTadOffset], extraParams), extraParams);
     }
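
In these loops the len argument drops out of shape::indexOffset as well; what remains is the 64-bit shapeInfo, a pre-cast 32-bit copy, and a flag choosing between them. The caller-side setup looks roughly like this (the castShapeInfo helper name is an assumption about the surrounding codebase, shown only to illustrate where xShapeInfoCast and canCastX come from):

    uint xShapeInfoCast[MAX_RANK * 2 + 4];                 // room for a full shapeInfo
    const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);

    for (Nd4jLong i = 0; i < len; ++i) {
        // 32-bit walk when the dims fit, 64-bit walk otherwise
        const auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
        // ... consume x[xOffset] ...
    }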

View File

@@ -15,7 +15,7 @@
  ******************************************************************************/
 //
-// @author iuriish@yahoo.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 #ifndef LIBND4J_SHAPEUTILS_H

View File

@@ -526,7 +526,7 @@ namespace shape {
     /* int *sub = new int[leftOverIndexLen];
        shape::ind2subOrder(tadShape,index,len,sub);
     */
-    shape::index2coords(leftOverIndexLen,tadShape, index,len, sub);
+    shape::index2coords(index, leftOverIndexLen, tadShape, sub);
     for(int i = 0; i < leftOverIndexLen; i++) {
@@ -609,7 +609,7 @@ namespace shape {
     if(dimensionLength > 1) {
         Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager);
-        Nd4jLong ret = shape::getOffset(0,shape::shapeOf(shapeInfo),shape::stride(shapeInfo),tad2Sub,shape::rank(shapeInfo));
+        Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub);
         if(ret < 0) {
             if (ptrManager == nullptr)
@@ -625,7 +625,7 @@ namespace shape {
     else {
         Nd4jLong *tad2Sub = this->tad2Sub(index, ptrManager);
-        Nd4jLong ret = shape::getOffset(0,shape::shapeOf(shapeInfo),shape::stride(shapeInfo),tad2Sub,shape::rank(shapeInfo));
+        Nd4jLong ret = shape::getOffset(shapeInfo, tad2Sub);
         if (ptrManager == nullptr)
             delete[] tad2Sub;
@@ -703,7 +703,7 @@ namespace shape {
     /* int *sub = new int[leftOverIndexLen];
        shape::ind2subOrder(tadShape,index,len,sub);
     */
-    shape::index2coords(leftOverIndexLen,tadShape,index,len, sub);
+    shape::index2coords(index, leftOverIndexLen, tadShape, sub);
     for(int i = 0; i < leftOverIndexLen; i++) {
         ret[leftOverIndexes[i]] = sub[i];
@@ -732,7 +732,7 @@ namespace shape {
     // return shape::createScalarShapeInfo();
     //ensure tad shapes get setup right for vectors
     if(dimensionLength > 1 && shape::isVector(shapeInfo))
         return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo);
     // case when tad coincides with whole array

View File

@@ -64,7 +64,7 @@ namespace nd4j {
     for (int i = 0; i < totalIterations; i++) {
-        shape::index2coords(xRank, xShape, i, totalIterations, xCoords);
+        shape::index2coords(i, xRank, xShape, xCoords);
         Parameters params;
         for (int j = 0; j < xRank; j++) {

View File

@@ -226,7 +226,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
         indexValue = OpType::update(indexValue, comp, extraParams);
     }
-    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
+    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
     z[zOffset] = (Z) indexValue.index;
 }
 }
@@ -243,7 +243,7 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
     auto indexValue = OpType::startingIndexValue(tad);
     for (uint j = 0; j < tadLen; j++) {
-        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, tadLen, canCastTad);
+        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
         functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
         indexValue = OpType::update(indexValue, comp, extraParams);
     }
@@ -266,12 +266,12 @@ void nd4j::IndexReductionLoops<X,Z>::loopIndexReduce(X* x, Nd4jLong* xShapeInfo,
     auto indexValue = OpType::startingIndexValue(tad);
     for (uint j = 0; j < tadLen; j++) {
-        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, tadLen, canCastTad);
+        auto tadOffset = shape::indexOffset(j, tadShapeInfo, castTadShapeInfo, canCastTad);
         functions::indexreduce::IndexValue<X> comp(tad[tadOffset], j);
         indexValue = OpType::update(indexValue, comp, extraParams);
     }
-    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, zLen, canCastZ);
+    auto zOffset = shape::indexOffset(i, zShapeInfo, castZShapeInfo, canCastZ);
     z[zOffset] = (Z) indexValue.index;
 }
 }

View File

@@ -15,7 +15,7 @@
  ******************************************************************************/
 //
-// @author Yurii Shyrma
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 #include <algorithm>
@@ -931,7 +931,7 @@ void ShapeUtils::evalIdxRangesForSubArr(const Nd4jLong subArrIdx, const Nd4jLon
     for(int i = 0; i < subArrRank; ++i)
         shapeOfSubArr[i] = shapeInfo[dimsToExclude[i] + 1];
-    shape::index2coords(subArrRank, shapeOfSubArr.data(), subArrIdx, indexes.data());
+    shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data());
     memset(idxRanges, 0, 2 * rank * sizeof(Nd4jLong));

View File

@@ -887,7 +887,7 @@ namespace shape {
      * @param indices the indices to iterate over
      * @return the double at the specified index
      */
-    ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank);
     ND4J_EXPORT _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset = 0);
     ND4J_EXPORT Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices);
@@ -897,20 +897,19 @@ namespace shape {
     /**
      * Convert a linear index to the corresponding coordinates
-     * for example if shape is {2, 4}, then index 5 corresponds to following coordinates
-     * -> [1, 1] in case of c order
-     * -> [1, 2] in case of f order
+     * for example if shape is {2, 4}, then index 5 corresponds to coordinates [1, 1]
      */
-    ND4J_EXPORT _CUDA_HD void index2coords(const int rank, const Nd4jLong *shape, Nd4jLong index, Nd4jLong arrLen, Nd4jLong *coords, const char order = 'c');
-    ND4J_EXPORT _CUDA_HD void index2coords(const int rank, const Nd4jLong *shape, Nd4jLong index, Nd4jLong *coords, const char order = 'c');
+    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords);
+    ND4J_EXPORT _CUDA_HD void index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords);
     /**
      * Convert coordinates to the corresponding linear index (sequence number in other words)
-     * for example if shape is {2, 4}, then:
-     * in case of c order and coordinates [1, 1] index 5 is returned
-     * in case of f order and coordinates [1, 2] index 5 is returned
+     * for example if shape is {2, 4} and coordinates [1, 1] then index 5 is returned
     */
-    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords, const char order = 'c');
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *coords);
+    ND4J_EXPORT _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *coords);
     /**
      * increment n-dimensional array by one iteration by changing coord appropriately
@@ -921,24 +920,10 @@ namespace shape {
     */
     /* calculates an array buffer offset for given "index" using following formula: offset = coord_0*stride_0 + coord_1*stride_1 + ... + coord_{rank-1}*stride_{rank-1}
-     * arrLen - array length
     */
-    ND4J_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen);
-    ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen);
-    ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOrderOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen, const char order);
-    ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, Nd4jLong arrLen, const bool useUnsigned);
+    ND4J_EXPORT _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo);
+    ND4J_EXPORT _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo);
+    ND4J_EXPORT _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned);
-    /**
-     * Compute the real linear indices for the given shape and stride
-     */
-    ND4J_EXPORT _CUDA_HD Nd4jLong *computeIndices(int rank, Nd4jLong *shape, Nd4jLong *stride);
-    /**
-     * Compute the real linear indices for the
-     * given shape buffer. Shape,stride and rank are derived
-     * from the buffer
-     */
-    ND4J_EXPORT _CUDA_HD Nd4jLong *computeIndices( Nd4jLong *shapeBuffer);
     ND4J_EXPORT _CUDA_HD void printShapeInfo(Nd4jLong *shapeInfo);
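
Taken together, the slimmed declarations compose like this (a usage sketch; shapeInfo is the packed [rank | shape | strides | type, ews, order] buffer used throughout, and i is any in-bounds linear index):

    Nd4jLong coords[MAX_RANK];
    shape::index2coords(i, shapeInfo, coords);                // linear index -> coordinates
    Nd4jLong j    = shape::coords2index(shapeInfo, coords);   // inverse: j == i
    Nd4jLong off  = shape::getIndexOffset(i, shapeInfo);      // index -> buffer offset in one step
    Nd4jLong off2 = shape::getOffset(shapeInfo, coords);      // coordinates -> offset; off2 == off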
@@ -1749,57 +1734,34 @@ __device__ INLINEDEF Nd4jLong *cuMalloc(Nd4jLong *buffer, long size) {
     return output;
 }
-/**
- * Compute the real linear indices for the given shape and stride
- */
-INLINEDEF _CUDA_HD Nd4jLong *computeIndices(int rank, Nd4jLong *shape, Nd4jLong *stride) {
-    Nd4jLong length = shape::prodLong(shape,rank);
-    traceNew(13);
-    Nd4jLong *ret = new Nd4jLong[length];
-    for(int i = 0; i < length; i++) {
-        Nd4jLong *idx = new Nd4jLong[rank];
-        shape::index2coords(rank, shape, i, idx, 'f');
-        ret[i] = shape::getOffset(0, shape, stride, idx, rank);
-        delete[] idx;
-    }
-    return ret;
-}
-/**
- * Compute the real linear indices for the given shape and stride
- */
-INLINEDEF _CUDA_HD Nd4jLong *computeIndices(Nd4jLong *shapeBuffer) {
-    return computeIndices(shape::rank(shapeBuffer),shape::shapeOf(shapeBuffer),shape::stride(shapeBuffer));
-}
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const Nd4jLong *shapeInfo, const Nd4jLong *indices) {
+    Nd4jLong index, shift = 1;;
+    index = indices[shapeInfo[0] - 1];
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        shift *= shapeInfo[i];
+        index += shift * indices[i - 2];
+    }
+    return index;
+}
 //////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices, const char order) {
+INLINEDEF _CUDA_HD Nd4jLong coords2index(const int rank, const Nd4jLong *shape, const Nd4jLong *indices) {
     Nd4jLong index, shift = 1;;
-    if(order == 'c') {
-        index = indices[rank - 1];
-        for(int i = rank - 2; i >= 0; --i) {
-            shift *= shape[i + 1];
-            index += shift * indices[i];
-        }
-    }
-    else {
-        index = indices[0];
-        for(int i = 1; i < rank; ++i) {
-            shift *= shape[i - 1];
-            index += shift * indices[i];
-        }
-    }
+    index = indices[rank - 1];
+    for(uint i = rank - 1; i >= 1; --i) {
+        shift *= shape[i];
+        index += shift * indices[i - 1];
+    }
     return index;
 }
 template <typename T>
 INLINEDEF _CUDA_HD void fill(T* buffer, T value, Nd4jLong length) {
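
A worked check of the new shapeInfo-based coords2index against the example in its header comment (standalone copy for illustration, with the same assumed shapeInfo layout):

    #include <cstdio>

    using Nd4jLong = long long;

    static Nd4jLong coords2index(const Nd4jLong* shapeInfo, const Nd4jLong* indices) {
        Nd4jLong shift = 1;
        Nd4jLong index = indices[shapeInfo[0] - 1];          // start from the last coordinate
        for (Nd4jLong i = shapeInfo[0]; i > 1; --i) {
            shift *= shapeInfo[i];                           // running product of trailing dims
            index += shift * indices[i - 2];
        }
        return index;
    }

    int main() {
        const Nd4jLong shapeInfo[] = {2, 2, 4, 4, 1, 0, 1, 99}; // shape {2, 4}, c order
        const Nd4jLong coords[]    = {1, 1};
        printf("%lld\n", coords2index(shapeInfo, coords));      // 5, as documented above
        return 0;
    }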
@@ -1809,85 +1771,110 @@ template <typename T>
 }
-//////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
-    const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2];
-    if(ews > 0 && order(shapeInfo) == 'c')
-        if (ews == 1)
-            return index;
-        else
-            return ews * index;
-    Nd4jLong offset = 0;
-    Nd4jLong rank = shapeInfo[0];
-    for(int i = 1; i <= shapeInfo[0]; ++i) {
-        arrLen /= shapeInfo[i];
-        if(arrLen > 0 && shapeInfo[i] > 1) {
-            offset += (index / arrLen) * shapeInfo[i + rank];
-            index %= arrLen;
-        }
-    }
-    return offset;
-}
-INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) {
-    const uint rank = shapeInfo[0];
-    const uint ews = shapeInfo[rank + rank + 2];
-    if(ews > 0 && shapeInfo[rank + rank + 3] == 99)
-        if (ews == 1)
-            return index;
-        else
-            return ews * index;
-    uint offset = 0;
-    for(uint i = 1; i <= rank; ++i) {
-        arrLen /= shapeInfo[i];
-        if(arrLen > 0 && shapeInfo[i] > 1) {
-            offset += (index / arrLen) * shapeInfo[i + rank];
-            index %= arrLen;
-        }
-    }
-    return offset;
-}
-INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, Nd4jLong arrLen, const bool useUnsigned) {
-    if(useUnsigned)
-        return getIndexOffset(static_cast<uint>(index), uShapeInfo, static_cast<uint>(arrLen));
-    return getIndexOffset(index, lShapeInfo, arrLen);
-}
+// //////////////////////////////////////////////////////////////////////
+// INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen) {
+//     const Nd4jLong ews = shapeInfo[shapeInfo[0] + shapeInfo[0] + 2];
+//     if(ews > 0 && order(shapeInfo) == 'c')
+//         if (ews == 1)
+//             return index;
+//         else
+//             return ews * index;
+//     Nd4jLong offset = 0;
+//     Nd4jLong rank = shapeInfo[0];
+//     for(int i = 1; i <= shapeInfo[0]; ++i) {
+//         arrLen /= shapeInfo[i];
+//         if(arrLen > 0 && shapeInfo[i] > 1) {
+//             offset += (index / arrLen) * shapeInfo[i + rank];
+//             index %= arrLen;
+//         }
+//     }
+//     return offset;
+// }
+// INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo, uint arrLen) {
+//     const uint rank = shapeInfo[0];
+//     const uint ews = shapeInfo[rank + rank + 2];
+//     if(ews > 0 && shapeInfo[rank + rank + 3] == 99)
+//         if (ews == 1)
+//             return index;
+//         else
+//             return ews * index;
+//     uint offset = 0;
+//     for(uint i = 1; i <= rank; ++i) {
+//         arrLen /= shapeInfo[i];
+//         if(arrLen > 0 && shapeInfo[i] > 1) {
+//             offset += (index / arrLen) * shapeInfo[i + rank];
+//             index %= arrLen;
+//         }
+//     }
+//     return offset;
+// }
 //////////////////////////////////////////////////////////////////////
-INLINEDEF _CUDA_HD Nd4jLong getIndexOrderOffset(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong arrLen, const char order) {
-    Nd4jLong offset = 0;
-    if(order == 'c') {
-        for(int i = 1; i <= *shapeInfo; ++i) {
-            arrLen /= shapeInfo[i];
-            if(arrLen > 0 && shapeInfo[i] > 1) {
-                offset += (index / arrLen) * shapeInfo[i + *shapeInfo];
-                index %= arrLen;
-            }
-        }
-    }
-    else {
-        for(int i = *shapeInfo; i >= 1 ; --i) {
-            arrLen /= shapeInfo[i];
-            if(arrLen > 0 && shapeInfo[i] > 1) {
-                offset += (index / arrLen) * shapeInfo[i + *shapeInfo];
-                index %= arrLen;
-            }
-        }
-    }
-    return offset;
-}
+INLINEDEF _CUDA_HD Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
+    if (shapeInfo[2 * shapeInfo[0] + 3] == 99) {
+        const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2];
+        if (ews == 1)
+            return index;
+        else if(ews > 1)
+            return ews * index;
+    }
+    Nd4jLong offset = 0;
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]];
+        index /= shapeInfo[i];
+    }
+    offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration
+    return offset;
+}
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD uint getIndexOffset(uint index, const uint *shapeInfo) {
+    if (shapeInfo[2 * shapeInfo[0] + 3] == 99) {
+        const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2];
+        if (ews == 1)
+            return index;
+        else if(ews > 1)
+            return ews * index;
+    }
+    uint offset = 0;
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]];
+        index /= shapeInfo[i];
+    }
+    offset += index * shapeInfo[1 + shapeInfo[0]]; // last iteration
+    return offset;
+}
+//////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong indexOffset(Nd4jLong index, const Nd4jLong* lShapeInfo, const uint* uShapeInfo, const bool useUnsigned) {
+    if(useUnsigned)
+        return getIndexOffset(static_cast<uint>(index), uShapeInfo);
+    return getIndexOffset(index, lShapeInfo);
+}
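
A worked check of the new getIndexOffset (standalone copy, same assumed layout). For a contiguous c-order array the ews fast path makes offset equal to index; for anything else the loop peels c-order coordinates off the index and applies the stored strides, which is what lets the arrLen parameter disappear:

    #include <cstdio>

    using Nd4jLong = long long;

    static Nd4jLong getIndexOffset(Nd4jLong index, const Nd4jLong* shapeInfo) {
        if (shapeInfo[2 * shapeInfo[0] + 3] == 99) {             // c order: try the ews fast path
            const Nd4jLong ews = shapeInfo[2 * shapeInfo[0] + 2];
            if (ews == 1) return index;
            if (ews > 1)  return ews * index;
        }
        Nd4jLong offset = 0;
        for (Nd4jLong i = shapeInfo[0]; i > 1; --i) {            // peel coordinates right-to-left
            offset += (index % shapeInfo[i]) * shapeInfo[i + shapeInfo[0]];
            index /= shapeInfo[i];
        }
        return offset + index * shapeInfo[1 + shapeInfo[0]];     // leftmost dimension last
    }

    int main() {
        const Nd4jLong cInfo[] = {2, 2, 4, 4, 1, 0, 1, 99};  // 2x4, c order: fast path
        const Nd4jLong fInfo[] = {2, 2, 4, 1, 2, 0, 1, 102}; // same shape, f order: full walk
        printf("%lld %lld\n", getIndexOffset(5, cInfo), getIndexOffset(5, fInfo)); // 5 3
        return 0;
    }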
/** /**
* *
* @param length * @param length
@ -2394,7 +2381,7 @@ template <typename T>
auto indices = new Nd4jLong[rank]; auto indices = new Nd4jLong[rank];
memset((void *) indices,0,rank * sizeof(Nd4jLong)); memset((void *) indices,0,rank * sizeof(Nd4jLong));
indices[0] = sliceIdx; indices[0] = sliceIdx;
Nd4jLong offset = shape::getOffset(0,newShape,newStride,indices,rank); Nd4jLong offset = shape::getOffset(newShapeBuffer, indices);
newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset; newShapeBuffer[shape::shapeInfoLength(newRank) - 3] = offset;
// set current order and ews // set current order and ews
@@ -3201,30 +3188,30 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
  * @param indices the indices to iterate over
  * @return the double at the specified index
  */
-INLINEDEF _CUDA_HD Nd4jLong getOffset(Nd4jLong baseOffset, const Nd4jLong *shape, const Nd4jLong *stride, const Nd4jLong *indices, const int rank) {
-    Nd4jLong offset = baseOffset;
-    for(int i = 0; i < rank; i++) {
-        if(shape[i] != 1)
-            offset += indices[i] * stride[i];
-    }
-    return offset;
-}
-
-INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) {
-    return shape::getOffset(baseOffset, shape::shapeOf(const_cast<Nd4jLong*>(shapeInfo)), shape::stride(const_cast<Nd4jLong*>(shapeInfo)), indices, shapeInfo[0]);
-}
-
-INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
-
-    Nd4jLong offset = 0;
-
-    for(uint i = 0; i < shapeInfo[0]; ++i)
-        if(shapeInfo[i + 1] != 1)
-            offset += indices[i] * shapeInfo[shapeInfo[0] + i + 1];
-
-    return offset;
-}
+//////////////////////////////////////////////////////////////////////////
+INLINEDEF _CUDA_HD Nd4jLong getOffset(const Nd4jLong *shapeInfo, const Nd4jLong *indices, Nd4jLong baseOffset) {
+
+    Nd4jLong offset = baseOffset;
+
+    for(uint i = 1; i <= shapeInfo[0]; ++i)
+        if(shapeInfo[i] != 1)
+            offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
+
+    return offset;
+}
+
+//////////////////////////////////////////////////////////////////////////
+INLINEDEF Nd4jLong getOffset(const Nd4jLong *shapeInfo, const std::vector<uint>& indices) {
+
+    Nd4jLong offset = 0;
+
+    for(uint i = 1; i <= shapeInfo[0]; ++i)
+        if(shapeInfo[i] != 1)
+            offset += indices[i - 1] * shapeInfo[shapeInfo[0] + i];
+
+    return offset;
+}
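Where getIndexOffset() starts from a linear index, getOffset() starts from per-dimension coordinates. A short sketch with the same illustrative layout as above:

    // strides {8, 2}: element (2, 1) of a 3x4 view
    Nd4jLong view[]   = {2,  3, 4,  8, 2,  0, 0, 99};
    Nd4jLong coords[] = {2, 1};
    shape::getOffset(view, coords, 0);   // 2 * 8 + 1 * 2 = 18

The shapeInfo[i] != 1 guard skips size-1 (broadcast) axes, whose stored strides may be arbitrary and must not contribute to the sum.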
@@ -4209,24 +4196,24 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
 INLINEDEF _CUDA_HD Nd4jLong subArrayIndex(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) {

     Nd4jLong maxIdxs[MAX_RANK];
-    shape::index2coords(shape::rank(maxShapeInfo), const_cast<Nd4jLong *>(maxShapeInfo)+1, const_cast<Nd4jLong&>(maxIdx), maxIdxs, shape::order(maxShapeInfo));
+    shape::index2coords(const_cast<Nd4jLong&>(maxIdx), maxShapeInfo, maxIdxs);

     Nd4jLong minIdxs[MAX_RANK];
     maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen);

-    return coords2index(shape::rank(minShapeInfo), minShapeInfo + 1, minIdxs);
+    return shape::coords2index(minShapeInfo, minIdxs);
 }

 //////////////////////////////////////////////////////////////////////
 INLINEDEF _CUDA_HD Nd4jLong subArrayOffset(const Nd4jLong maxIdx, const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, const int* dimsToExclude, const int dimsLen) {

     Nd4jLong maxIdxs[MAX_RANK];
-    shape::index2coords(shape::rank(maxShapeInfo), const_cast<Nd4jLong *>(maxShapeInfo)+1, const_cast<Nd4jLong&>(maxIdx), maxIdxs, shape::order(maxShapeInfo));
+    shape::index2coords(const_cast<Nd4jLong&>(maxIdx), maxShapeInfo, maxIdxs);

     Nd4jLong minIdxs[MAX_RANK];
     maxIndToMinInd(maxIdxs, minIdxs, maxShapeInfo, minShapeInfo, dimsToExclude, dimsLen);

-    return getOffset(0, minShapeInfo + 1, minShapeInfo + shape::rank(minShapeInfo) + 1, minIdxs, shape::rank(minShapeInfo));
+    return getOffset(minShapeInfo, minIdxs);
 }
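In broadcast terms, these two helpers project a linear index of the bigger ("max") array onto the matching element of the smaller ("min") array: decompose into coordinates, map away the excluded dimensions, recompose. A rough sketch for the simplest case, a single excluded leading axis; everything here, including the dimsToExclude convention, is an assumption for illustration:

    // max: 2x3x4, min: 3x4, dim 0 of max absent from min
    Nd4jLong maxSI[] = {3,  2, 3, 4,  12, 4, 1,  0, 1, 99};
    Nd4jLong minSI[] = {2,  3, 4,  4, 1,  0, 1, 99};
    int dims[] = {0};

    // maxIdx 17 -> max coords {1,1,1} -> min coords {1,1} -> min index 5 (== 17 % 12)
    shape::subArrayIndex(17, maxSI, minSI, dims, 1);    // 5
    shape::subArrayOffset(17, maxSI, minSI, dims, 1);   // min's buffer offset for the same coords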
 //////////////////////////////////////////////////////////////////////
@@ -4246,7 +4233,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
     int N, minI, maxI;

     // calculate min per-dim-indices which corresponds to absolute minIdx index
-    shape::index2coords(rankMin, minShapeInfo + 1, minIdx, indices, order(minShapeInfo));
+    shape::index2coords(minIdx, minShapeInfo, indices);

     // transform storage indices to contain per-dim max indices, purpose - memory saving
     // fill increment array as well
@@ -4277,7 +4264,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
     maxI = rankMax-1;
     N = 0;
     int step;
-    maxOffsets[N++] = shape::getOffset(0, maxShapeInfo + 1, maxShapeInfo + rankMax + 1, indices, rankMax);
+    maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices);

     // nested loops - producing of absolute indices for max array
     while(maxI >= 0) {
@@ -4290,7 +4277,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
             step = -1;
         }
         else {
-            maxOffsets[N++] = shape::getOffset(0, maxShapeInfo + 1, maxShapeInfo + rankMax + 1, indices, rankMax);
+            maxOffsets[N++] = shape::getOffset(maxShapeInfo, indices);
             step = rankMax - 1 - maxI;
         }
     }
@@ -4322,7 +4309,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
     int N, minI, maxI;

     // calculate min per-dim-indices which corresponds to absolute minIdx index
-    shape::index2coords(rankMin, minShapeInfo + 1, minIdx, indices, order(minShapeInfo));
+    shape::index2coords(minIdx, minShapeInfo, indices);

     // transform storage indices to contain per-dim max indices, purpose - memory saving
     // fill increment array as well
@@ -4353,7 +4340,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
     maxI = rankMax-1;
     N = 0;
     int step;
-    maxIdxs[N++] = coords2index(rankMax, maxShapeInfo + 1, indices);
+    maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices);

     // nested loops - producing of absolute indices for max array
     while(maxI >= 0) {
@@ -4366,7 +4353,7 @@ INLINEDEF _CUDA_HD void maxIndToMinInd(Nd4jLong* maxIdxs, Nd4jLong* minIdxs, con
             step = -1;
         }
         else {
-            maxIdxs[N++] = coords2index(rankMax, maxShapeInfo + 1, indices);
+            maxIdxs[N++] = shape::coords2index(maxShapeInfo, indices);
             step = rankMax - 1 - maxI;
         }
     }
@@ -4693,37 +4680,23 @@ INLINEDEF _CUDA_HD void calcSubArrShapeAndOffsets(const Nd4jLong* wholeShapeInfo
 }

 //////////////////////////////////////////////////////////////////////
-INLINEDEF void _CUDA_HD index2coords(const int rank, const Nd4jLong *shape, Nd4jLong index, Nd4jLong *coords, const char order) {
-    Nd4jLong arrLen = shape::prodLong(shape, rank);
-    shape::index2coords(rank, shape, index, arrLen, coords, order);
-}
-
-INLINEDEF void _CUDA_HD index2coords(const int rank, const Nd4jLong *shape, Nd4jLong index, Nd4jLong arrLen, Nd4jLong *coords, const char order) {
-
-    if(order == 'c') {
-        for(int i = 0; i < rank; i++) {
-            arrLen /= shape[i];
-            if(arrLen > 0 && shape[i] > 1) {
-                coords[i] = index / arrLen;
-                index %= arrLen;
-            }
-            else
-                coords[i] = 0;
-        }
-    }
-    else {
-        for(int i = rank - 1; i >= 0; i--) {
-            arrLen /= shape[i];
-            if(arrLen > 0 && shape[i] > 1) {
-                coords[i] = index / arrLen;
-                index %= arrLen;
-            }
-            else
-                coords[i] = 0;
-        }
-    }
-}
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const Nd4jLong *shapeInfo, Nd4jLong *coords) {
+
+    for(uint i = shapeInfo[0]; i > 1; --i) {
+        coords[i - 1] = index % shapeInfo[i];
+        index /= shapeInfo[i];
+    }
+    coords[0] = index;      // last iteration
+}
+
+//////////////////////////////////////////////////////////////////////
+INLINEDEF void _CUDA_HD index2coords(Nd4jLong index, const int rank, const Nd4jLong *shape, Nd4jLong *coords) {
+
+    for(uint i = rank - 1; i > 0; --i) {
+        coords[i] = index % shape[i];
+        index /= shape[i];
+    }
+    coords[0] = index;      // last iteration
+}
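Note the semantic simplification: the old overloads decomposed the index according to an explicit order argument, while the new ones always walk the shape row-major and leave ordering to the strides already stored in shapeInfo. A round-trip sketch with the illustrative 3x4 layout used earlier:

    Nd4jLong si[] = {2,  3, 4,  4, 1,  0, 1, 99};
    Nd4jLong coords[2];

    shape::index2coords(7, si, coords);    // coords = {1, 3}
    shape::coords2index(si, coords);       // 7 again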
 //////////////////////////////////////////////////////////////////////

View File

@@ -170,13 +170,13 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oX = x + tadOffsets[i];
     auto oZ = z + zTadOffset[i];

     PRAGMA_OMP_SIMD
     for (unsigned int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -190,14 +190,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -211,14 +211,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[offset], y[yOffset]);
     }
 }
@@ -232,14 +232,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[xOffset], y[offset]);
     }
 }
@@ -255,15 +255,15 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
     }
 }
@@ -362,7 +362,7 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (unsigned int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -382,8 +382,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -403,8 +403,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto xOffset = shape::indexOffset(f, yShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[xOffset], oY[offset]);
     }
 }
@@ -424,8 +424,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[offset], oY[yOffset]);
     }
 }
@@ -447,9 +447,9 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
     }
 }

View File

@@ -126,7 +126,7 @@ namespace functions {
 if (zTadShapeInfo == nullptr) {
     zTadShapeInfo = xTadShapeShapeInfo;
     zTadOffset = tadOffsets;
 }

 auto lenZ = shape::length(zTadShapeInfo);
 auto lenY = shape::length(yShapeInfo);
@@ -140,7 +140,7 @@ namespace functions {
 auto zEws = shape::elementWiseStride(zTadShapeInfo);

 const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

 if (kindOfLoop == nd4j::LoopKind::EWS1) {
     PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
     for (int i = 0; i < tads; i++) {
@@ -170,15 +170,15 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     // TODO: cover this codebranch with tests
     // all this stuff already happens within thread
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -192,14 +192,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -213,14 +213,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[offset], y[yOffset]);
     }
 }
@@ -234,14 +234,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[xOffset], y[offset]);
     }
 }
@@ -257,15 +257,15 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
     }
 }
@@ -365,7 +365,7 @@ namespace functions {
     // all this stuff already happens within thread
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
        oZ[offset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -385,8 +385,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -406,8 +406,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[xOffset], oY[offset]);
     }
 }
@@ -427,8 +427,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[offset], oY[yOffset]);
     }
 }
@@ -450,9 +450,9 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
     }
 }

View File

@@ -126,7 +126,7 @@ namespace functions {
 if (zTadShapeInfo == nullptr) {
     zTadShapeInfo = xTadShapeShapeInfo;
     zTadOffset = tadOffsets;
 }

 auto lenZ = shape::length(zTadShapeInfo);
 auto lenY = shape::length(yShapeInfo);
@@ -140,7 +140,7 @@ namespace functions {
 auto zEws = shape::elementWiseStride(zTadShapeInfo);

 const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

 if (kindOfLoop == nd4j::LoopKind::EWS1) {
     PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
     for (int i = 0; i < tads; i++) {
@@ -170,15 +170,15 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     // TODO: cover this codebranch with tests
     // all this stuff already happens within thread
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -192,14 +192,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[offset], y[offset]);
     }
 }
@@ -213,14 +213,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[offset], y[yOffset]);
     }
 }
@@ -234,14 +234,14 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(oX[xOffset], y[offset]);
     }
 }
@@ -257,15 +257,15 @@ namespace functions {
 PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
 for (int i = 0; i < tads; i++) {
     auto oZ = z + zTadOffset[i];
     auto oX = x + tadOffsets[i];

     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
-        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
     }
 }
@@ -365,7 +365,7 @@ namespace functions {
     // all this stuff already happens within thread
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
         oZ[offset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -385,8 +385,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[offset], oY[offset]);
     }
 }
@@ -406,8 +406,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[xOffset], oY[offset]);
     }
 }
@@ -427,8 +427,8 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
         oZ[offset] = OpType::op(x[offset], oY[yOffset]);
     }
 }
@@ -450,9 +450,9 @@ namespace functions {
     PRAGMA_OMP_SIMD
     for (int f = 0; f < tadLength; f++) {
-        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
-        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
-        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
+        auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
+        auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
+        auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
         oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
     }
 }

View File

@@ -92,7 +92,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(void *vx, Nd4jLong *xShapeInfo, void *vex
 auto ulen = info.getItersPerThread(threadNum);

 for (Nd4jLong i = 0; i < ulen; i++) {
-    auto offset = shape::indexOffset(threadOffset + i, xShapeInfo, xShapeInfoCast, len, canCastX);
+    auto offset = shape::indexOffset(threadOffset + i, xShapeInfo, xShapeInfoCast, canCastX);
     IndexValue<X> curr(x[offset], threadOffset + i);
     local = OpType::update(local, curr, extraParams);
 }

View File

@@ -137,7 +137,7 @@ namespace functions {
 void *vz,
 Nd4jLong* zShapeInfo,
 void *vextraParams) {

 auto x = reinterpret_cast<X *>(vx);
 auto y = reinterpret_cast<Y *>(vy);
 auto z = reinterpret_cast<Z *>(vz);
@@ -152,13 +152,13 @@ namespace functions {
 if (shape::isScalar(yShapeInfo)) {

     uint xShapeInfoCast[MAX_RANK];
     const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

     if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
@@ -166,25 +166,25 @@ namespace functions {
             PRAGMA_OMP_SIMD
             for(unsigned int i = 0; i < ulen; i++) {
-                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
+                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                 z[offset] = OpType::op(x[offset], y[0], extraParams);
             }
         }
     }
     else {

         uint zShapeInfoCast[MAX_RANK];
         const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for(unsigned int i = 0; i < ulen; i++) {
-                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                 z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
             }
         }
@@ -192,18 +192,18 @@ namespace functions {
     return;
 }

 const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo);
 const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);

 if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
     exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
 }
 else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
     exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
 }
 else {

     if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@@ -211,14 +211,14 @@ namespace functions {
         bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for (unsigned int i = 0; i < ulen; i++) {
-                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
+                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                 z[offset] = OpType::op(x[offset], y[offset], extraParams);
             }
         }
@@ -231,15 +231,15 @@ namespace functions {
         bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for (unsigned int i = 0; i < ulen; i++) {
-                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                 z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
             }
         }
@@ -252,15 +252,15 @@ namespace functions {
         bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for (unsigned int i = 0; i < ulen; i++) {
-                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                 z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
             }
         }
@@ -273,15 +273,15 @@ namespace functions {
         bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for (unsigned int i = 0; i < ulen; i++) {
-                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                 z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
             }
         }
@@ -296,16 +296,16 @@ namespace functions {
         const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

             PRAGMA_OMP_SIMD
             for (unsigned int i = 0; i < ulen; i++) {
-                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
-                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
+                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                 z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
             }
         }

View File

@ -61,7 +61,7 @@ namespace functions {
Nd4jLong zEws, Nd4jLong zEws,
void *vextraParams, void *vextraParams,
const Nd4jLong n) { const Nd4jLong n) {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
@ -72,9 +72,9 @@ namespace functions {
if (xEws == 1 && yEws == 1 && zEws == 1) { if (xEws == 1 && yEws == 1 && zEws == 1) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum); Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + threadOffset; auto xi = x + threadOffset;
auto yi = y + threadOffset; auto yi = y + threadOffset;
auto zi = z + threadOffset; auto zi = z + threadOffset;
@ -88,9 +88,9 @@ namespace functions {
else { else {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
Nd4jLong threadOffset = info.getThreadOffset(threadNum); Nd4jLong threadOffset = info.getThreadOffset(threadNum);
auto xi = x + xEws*threadOffset; auto xi = x + xEws*threadOffset;
auto yi = y + yEws*threadOffset; auto yi = y + yEws*threadOffset;
auto zi = z + zEws*threadOffset; auto zi = z + zEws*threadOffset;
@ -151,33 +151,33 @@ namespace functions {
if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum); auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) { for(Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[0], extraParams); z[offset] = OpType::op(x[offset], y[0], extraParams);
} }
} }
} }
else { else {
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum); auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for(Nd4jLong i = 0; i < ulen; i++) { for(Nd4jLong i = 0; i < ulen; i++) {
auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
} }
} }
@ -190,11 +190,11 @@ namespace functions {
if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
} }
else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
} }
else { else {
if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@ -202,83 +202,83 @@ namespace functions {
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum); auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
z[offset] = OpType::op(x[offset], y[offset], extraParams); z[offset] = OpType::op(x[offset], y[offset], extraParams);
} }
} }
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint zShapeInfoCast[MAX_RANK]; uint zShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum); auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
z[zOffset] = OpType::op(x[offset], y[offset], extraParams); z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
} }
} }
} }
else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
uint xShapeInfoCast[MAX_RANK]; uint xShapeInfoCast[MAX_RANK];
uint yShapeInfoCast[MAX_RANK]; uint yShapeInfoCast[MAX_RANK];
const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
{ {
auto threadNum = omp_get_thread_num(); auto threadNum = omp_get_thread_num();
auto threadOffset = info.getThreadOffset(threadNum); auto threadOffset = info.getThreadOffset(threadNum);
auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum)); auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (Nd4jLong i = 0; i < ulen; i++) { for (Nd4jLong i = 0; i < ulen; i++) {
auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
-                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                     z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
                 }
             }
         }
         else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
             uint xShapeInfoCast[MAX_RANK];
             uint yShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                    auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                     z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
                 }
             }
         }
         else {
             uint xShapeInfoCast[MAX_RANK];
             uint yShapeInfoCast[MAX_RANK];
             uint zShapeInfoCast[MAX_RANK];
@@ -287,16 +287,16 @@ namespace functions {
             const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
-                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
+                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                     z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
                 }
             }
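The one change threaded through every loop in this file is the trimmed `shape::indexOffset` signature: the trailing array-length argument is dropped, since the linear-index-to-offset translation never needed it. A minimal sketch of that contract follows; the names and the `[rank | shape... | stride...]` shapeInfo layout are assumptions for illustration, not the actual libnd4j implementation:

    // Hedged sketch: map a flat index to a strided buffer offset using only
    // the shape descriptor; note that no total-length parameter is required.
    // Assumed layout: shapeInfo = [rank, shape[0..rank-1], stride[0..rank-1], ...]
    using Nd4jLong = long long;

    inline Nd4jLong indexOffsetSketch(Nd4jLong index, const Nd4jLong* shapeInfo) {
        const int rank = static_cast<int>(shapeInfo[0]);
        const Nd4jLong* shape  = shapeInfo + 1;
        const Nd4jLong* stride = shapeInfo + 1 + rank;
        Nd4jLong offset = 0;
        // peel coordinates off the linear index in C (row-major) order
        for (int dim = rank - 1; dim >= 0; --dim) {
            offset += (index % shape[dim]) * stride[dim];
            index  /= shape[dim];
        }
        return offset;
    }

The `xShapeInfoCast`/`canCastX` pairs seen above select a 32-bit variant of the same computation when the shape data fits in `uint`.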

----- next file -----

@@ -61,7 +61,7 @@ namespace functions {
                               Nd4jLong zEws,
                               void *vextraParams,
                               const Nd4jLong n) {
         auto x = reinterpret_cast<X *>(vx);
         auto y = reinterpret_cast<X *>(vy);
         auto z = reinterpret_cast<X *>(vz);
@@ -72,9 +72,9 @@ namespace functions {
         if (xEws == 1 && yEws == 1 && zEws == 1) {
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 Nd4jLong threadOffset = info.getThreadOffset(threadNum);
                 auto xi = x + threadOffset;
                 auto yi = y + threadOffset;
                 auto zi = z + threadOffset;
@@ -88,9 +88,9 @@ namespace functions {
         else {
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 Nd4jLong threadOffset = info.getThreadOffset(threadNum);
                 auto xi = x + xEws*threadOffset;
                 auto yi = y + yEws*threadOffset;
                 auto zi = z + zEws*threadOffset;
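For contiguous inputs this file skips `indexOffset` entirely: an element-wise stride (EWS) of 1 means a plain pointer walk, and any other nonzero EWS means scaling the loop index. A self-contained sketch of those two fast paths, with illustrative names and kept serial for brevity:

    // Hedged sketch of the EWS fast paths: ews == 1 is a unit-stride walk,
    // ews != 1 scales each index; only non-EWS shapes need full offset math.
    template <typename T, typename OpFn>
    void pairwiseEwsSketch(const T* x, long long xEws, const T* y, long long yEws,
                           T* z, long long zEws, long long n, OpFn op) {
        if (xEws == 1 && yEws == 1 && zEws == 1) {
            for (long long i = 0; i < n; i++)          // contiguous buffers
                z[i] = op(x[i], y[i]);
        } else {
            for (long long i = 0; i < n; i++)          // strided buffers
                z[i * zEws] = op(x[i * xEws], y[i * yEws]);
        }
    }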
@@ -151,33 +151,33 @@ namespace functions {
         if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for(Nd4jLong i = 0; i < ulen; i++) {
-                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
+                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                     z[offset] = OpType::op(x[offset], y[0], extraParams);
                 }
             }
         }
         else {
             uint zShapeInfoCast[MAX_RANK];
             const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for(Nd4jLong i = 0; i < ulen; i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                     z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
                 }
             }
@@ -190,11 +190,11 @@ namespace functions {
         if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
             exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
         }
         else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { // not same shape
             exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
         }
         else {
             if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
@@ -202,83 +202,83 @@ namespace functions {
                 const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                 PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                 {
                     auto threadNum = omp_get_thread_num();
                     auto threadOffset = info.getThreadOffset(threadNum);
                     auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                     PRAGMA_OMP_SIMD
                     for (Nd4jLong i = 0; i < ulen; i++) {
-                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
+                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                         z[offset] = OpType::op(x[offset], y[offset], extraParams);
                     }
                 }
             }
             else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
                 uint xShapeInfoCast[MAX_RANK];
                 uint zShapeInfoCast[MAX_RANK];
                 const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                 const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
                 PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                 {
                     auto threadNum = omp_get_thread_num();
                     auto threadOffset = info.getThreadOffset(threadNum);
                     auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                     PRAGMA_OMP_SIMD
                     for (Nd4jLong i = 0; i < ulen; i++) {
-                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                         z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
                     }
                 }
             }
             else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
                 uint xShapeInfoCast[MAX_RANK];
                 uint yShapeInfoCast[MAX_RANK];
                 const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                 const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
                 PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                 {
                     auto threadNum = omp_get_thread_num();
                     auto threadOffset = info.getThreadOffset(threadNum);
                     auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                     PRAGMA_OMP_SIMD
                     for (Nd4jLong i = 0; i < ulen; i++) {
-                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                         z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
                     }
                 }
             }
             else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
                 uint xShapeInfoCast[MAX_RANK];
                 uint yShapeInfoCast[MAX_RANK];
                 const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                 const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
                 PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                 {
                     auto threadNum = omp_get_thread_num();
                     auto threadOffset = info.getThreadOffset(threadNum);
                     auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                     PRAGMA_OMP_SIMD
                     for (Nd4jLong i = 0; i < ulen; i++) {
-                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                        auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
+                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                        auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                         z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
                     }
                 }
             }
             else {
                 uint xShapeInfoCast[MAX_RANK];
                 uint yShapeInfoCast[MAX_RANK];
                 uint zShapeInfoCast[MAX_RANK];
@@ -287,16 +287,16 @@ namespace functions {
                 const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
                 PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                 {
                     auto threadNum = omp_get_thread_num();
                     auto threadOffset = info.getThreadOffset(threadNum);
                     auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                     PRAGMA_OMP_SIMD
                     for (Nd4jLong i = 0; i < ulen; i++) {
-                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
-                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
-                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
+                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
+                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                         z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
                     }
                 }
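Every parallel region in these files follows the same `OmpLaunchHelper` recipe: each thread asks for its starting offset and iteration count, then runs a SIMD inner loop over `i + threadOffset`. A runnable sketch of the even-split partitioning such a helper plausibly performs; the chunking arithmetic is an assumption, only the usage pattern is taken from the loops above:

    // Hedged sketch: split `length` elements across OpenMP threads, giving
    // each thread a contiguous [offset, offset + iters) chunk, mirroring the
    // getThreadOffset/getItersPerThread calls above.
    #include <omp.h>

    void partitionedLoopSketch(long long length) {
        #pragma omp parallel
        {
            const int t  = omp_get_thread_num();
            const int nt = omp_get_num_threads();
            const long long chunk = length / nt;
            const long long rem   = length % nt;
            const long long iters  = chunk + (t < rem ? 1 : 0);        // first `rem` threads take one extra
            const long long offset = t * chunk + (t < rem ? t : rem);  // start of this thread's chunk
            #pragma omp simd
            for (long long i = 0; i < iters; i++) {
                // the real loops compute buffer offsets from (i + offset) here
                (void)(i + offset);
            }
        }
    }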

----- next file -----

@@ -50,27 +50,27 @@ namespace functions {
             return;
         }
         auto length = shape::length(zShapeInfo);
         // nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
         nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
         nd4j::OmpLaunchHelper info(length);
         if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
             uint xShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
+                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                     z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
                 }
             }
@@ -79,19 +79,19 @@ namespace functions {
             uint xShapeInfoCast[MAX_RANK];
             uint zShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
-                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, length, canCastZ);
+                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                     z[zOffset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments);
                 }
             }
@@ -100,19 +100,19 @@ namespace functions {
             uint xShapeInfoCast[MAX_RANK];
             uint yShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
-                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, length, canCastY);
+                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                     z[offset] = OpClass::op(x[offset], y[yOffset], i, length, rng, extraArguments);
                 }
             }
@@ -121,19 +121,19 @@ namespace functions {
             uint xShapeInfoCast[MAX_RANK];
             uint yShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < info.getItersPerThread(threadNum); i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
-                    auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, length, canCastY);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
                     z[offset] = OpClass::op(x[xOffset], y[offset], i, length, rng, extraArguments);
                 }
             }
@@ -143,21 +143,21 @@ namespace functions {
             uint xShapeInfoCast[MAX_RANK];
             uint yShapeInfoCast[MAX_RANK];
             uint zShapeInfoCast[MAX_RANK];
             const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
             const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
             const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
-                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, length, canCastY);
-                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, length, canCastZ);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, canCastY);
+                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                     z[zOffset] = OpClass::op(x[xOffset], y[yOffset], i, length, rng, extraArguments);
                 }
             }
@@ -185,18 +185,18 @@ namespace functions {
         nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
         nd4j::OmpLaunchHelper info(length);
         if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
+                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                     z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments);
                 }
             }
@@ -207,15 +207,15 @@ namespace functions {
             const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
             PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
             {
                 auto threadNum = omp_get_thread_num();
                 auto threadOffset = info.getThreadOffset(threadNum);
                 auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
                 PRAGMA_OMP_SIMD
                 for (Nd4jLong i = 0; i < ulen; i++) {
-                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, length, canCastX);
-                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, length, canCastZ);
+                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                     z[zOffset] = OpClass::op(x[xOffset], i, length, rng, extraArguments);
                 }
             }
@@ -231,7 +231,7 @@ namespace functions {
         auto extraArguments = reinterpret_cast<X *>(vextraArguments);
         auto length = shape::length(zShapeInfo);
         //nd4j::random::RandomBuffer *buffer = reinterpret_cast<nd4j::random::RandomBuffer *> (state);
         nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state);
         nd4j::OmpLaunchHelper info(length);
@@ -240,14 +240,14 @@ namespace functions {
         const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);
         PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
         {
             auto threadNum = omp_get_thread_num();
             auto threadOffset = info.getThreadOffset(threadNum);
             auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
             PRAGMA_OMP_SIMD
             for (Nd4jLong i = 0; i < ulen; i++) {
-                auto offset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, length, canCastZ);
+                auto offset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                 z[offset] = OpClass::op(i+threadOffset, length, rng, extraArguments);
             }
         }

----- next file -----

@@ -77,7 +77,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < maxThreads; e++)
@@ -112,7 +112,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < omp_get_max_threads(); e++)
                 start = OpType::update(start, intermediate[e], extraParams);

----- next file -----

@@ -81,7 +81,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < maxThreads; e++)
@@ -115,7 +115,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < omp_get_max_threads(); e++)
                 start = OpType::update(start, intermediate[e], extraParams);

----- next file -----

@@ -77,7 +77,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < maxThreads; e++)
@@ -113,7 +113,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < omp_get_max_threads(); e++)
                 start = OpType::update(start, intermediate[e], extraParams);

----- next file -----

@@ -79,7 +79,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < maxThreads; e++)
@@ -117,7 +117,7 @@ namespace functions {
             PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(maxThreads)
             for(Nd4jLong i = 0; i < length; ++i)
-                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX)], extraParams), extraParams);
+                intermediate[omp_get_thread_num()] = OpType::update(intermediate[omp_get_thread_num()], OpType::op(x[shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX)], extraParams), extraParams);
             for (int e = 0; e < maxThreads; e++)
                 start = OpType::update(start, intermediate[e], extraParams);

----- next file -----

@@ -95,7 +95,7 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
         PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
         for(unsigned int i = 0; i < length; i++) {
             const auto threadNum = omp_get_thread_num();
-            auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX);
+            auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
             intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[offset], y[offset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
         }
     } else {
@@ -105,8 +105,8 @@ void Reduce3<X,Z>::execScalar(void *vx, Nd4jLong *xShapeInfo,
         PRAGMA_OMP_PARALLEL_FOR_SIMD_THREADS(t._numThreads)
         for(unsigned int i = 0; i < length; i++) {
             const auto threadNum = omp_get_thread_num();
-            auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCastX);
-            auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, length, canCastY);
+            auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX);
+            auto yOffset = shape::indexOffset(i, yShapeInfo, yShapeInfoCast, canCastY);
             intermediate[threadNum] = OpType::update(intermediate[threadNum], OpType::op(x[xOffset], y[yOffset], extraParamsLocal + 3 * threadNum), extraParamsLocal + 3 * threadNum);
         }
     }
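The `uint xShapeInfoCast[MAX_RANK]` buffers paired with `canCastX` flags throughout these reductions are a width optimization: if every shape/stride entry fits in 32 bits, the hot loop can do its index arithmetic in `uint` rather than in 64-bit `Nd4jLong`. A hedged stand-in for `nd4j::DataTypeUtils::castShapeInfo`; the lossless-copy check is an assumption for illustration, not taken from the diff:

    // Hedged sketch: copy the 64-bit shape descriptor into 32-bit scratch
    // space and report whether the copy was lossless; callers fall back to
    // the 64-bit path when it was not.
    #include <cstdint>
    #include <limits>

    inline bool castShapeInfoSketch(const long long* shapeInfo, uint32_t* out) {
        const int rank = static_cast<int>(shapeInfo[0]);
        const int entries = 1 + 2 * rank;   // rank, shape[], stride[]
        for (int e = 0; e < entries; e++) {
            if (shapeInfo[e] < 0 || shapeInfo[e] > (long long) std::numeric_limits<uint32_t>::max())
                return false;               // canCast == false: keep Nd4jLong math
            out[e] = static_cast<uint32_t>(shapeInfo[e]);
        }
        return true;                        // canCast == true: loops may use `out`
    }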

----- next file -----

@@ -33,14 +33,14 @@ namespace scalar {
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template<typename OpType>
void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
                                         void *vextraParams,
                                         void *vz, Nd4jLong *zShapeInfo,
                                         void *vscalars,
                                         int *dimension, int dimensionLength,
                                         Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                         Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<Z *>(vz);
    auto scalars = reinterpret_cast<Y *>(vscalars);
@@ -159,37 +159,37 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
        PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
+               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                z[offset] = OpType::op(x[offset], scalar, extraParams);
            }
        }
    }
    else {
        uint zShapeInfoCast[MAX_RANK];
        const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
        PRAGMA_OMP_PARALLEL_THREADS_IF(info._numThreads, allowParallelism)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
-               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
+               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////
@@ -200,7 +200,7 @@ void ScalarTransform<X, Y, Z>::transform(void *vx, Nd4jLong xEws,
                                         void *vscalar,
                                         void *vextraParams,
                                         const Nd4jLong len, bool allowParallelism) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<Z *>(vz);
    auto scalar = reinterpret_cast<Y *>(vscalar)[0];

----- next file -----

@@ -33,14 +33,14 @@ namespace functions {
template<typename X, typename Z>
template<typename OpType>
void ScalarBoolTransform<X, Z>::transform(void *vx, Nd4jLong *xShapeInfo,
                                          void *vextraParams,
                                          void *vz, Nd4jLong *zShapeInfo,
                                          void *vscalars,
                                          int *dimension, int dimensionLength,
                                          Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                          Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<Z *>(vz);
    auto scalars = reinterpret_cast<X *>(vscalars);
@@ -63,7 +63,7 @@ namespace functions {
        printf("ScalarBoolTransform<X, Z>::transform: super-bad loop visited. Shouldn't ever happen\n");
        return;
    }
    int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
    if (kindOfLoop == nd4j::LoopKind::EWS1) {
@@ -76,7 +76,7 @@ namespace functions {
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
        }
    }
    else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO
        PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
        for (unsigned int r = 0; r < numTads; r++) {
@@ -87,7 +87,7 @@ namespace functions {
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
        }
    }
}
template<typename X, typename Y>
@@ -139,7 +139,7 @@ namespace functions {
                                          Nd4jLong *zShapeInfo,
                                          void *vscalar,
                                          void *vextraParams) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<Z *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];
@@ -162,41 +162,41 @@ namespace functions {
    const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
    nd4j::OmpLaunchHelper info(len);
    if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
+               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                z[offset] = OpType::op(x[offset], scalar, extraParams);
            }
        }
    }
    else {
        uint zShapeInfoCast[MAX_RANK];
        const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
-               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
+               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
            }
        }
    }
}
@@ -213,7 +213,7 @@ namespace functions {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<Z *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];
    auto extraParams = reinterpret_cast<X *>(vextraParams);
    nd4j::OmpLaunchHelper info(len);
@@ -231,7 +231,7 @@ namespace functions {
            for (unsigned int i = 0; i < ulen; i++)
                zi[i] = OpType::op(xi[i], scalar, extraParams);
        }
    }
    else {
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)

----- next file -----

@@ -34,13 +34,13 @@ namespace functions {
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, Nd4jLong *xShapeInfo,
                                      void *vextraParams,
                                      void *vz, Nd4jLong *zShapeInfo,
                                      void *vscalars,
                                      int *dimension, int dimensionLength,
                                      Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                      Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalars = reinterpret_cast<X *>(vscalars);
@@ -63,7 +63,7 @@ namespace functions {
        printf("ScalarIntTransform<X>::transform: super-bad loop visited. Shouldn't ever happen\n");
        return;
    }
    int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());
    if (kindOfLoop == nd4j::LoopKind::EWS1) {
@@ -76,7 +76,7 @@ namespace functions {
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
        }
    }
    else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO
        PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
        for (unsigned int r = 0; r < numTads; r++) {
@@ -87,7 +87,7 @@ namespace functions {
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
        }
    }
}
template<typename X>
@@ -139,7 +139,7 @@ namespace functions {
                                      Nd4jLong *zShapeInfo,
                                      void *vscalar,
                                      void *vextraParams) {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];
@@ -162,41 +162,41 @@ namespace functions {
    const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);
    nd4j::OmpLaunchHelper info(len);
    if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
+               auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
                z[offset] = OpType::op(x[offset], scalar, extraParams);
            }
        }
    }
    else {
        uint zShapeInfoCast[MAX_RANK];
        const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));
            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
-               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
-               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
+               auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, canCastX);
+               auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, canCastZ);
                z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
            }
        }
    }
}
@@ -213,7 +213,7 @@ namespace functions {
    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];
    auto extraParams = reinterpret_cast<X *>(vextraParams);
    nd4j::OmpLaunchHelper info(len);
@@ -231,7 +231,7 @@ namespace functions {
            for (unsigned int i = 0; i < ulen; i++)
                zi[i] = OpType::op(xi[i], scalar, extraParams);
        }
    }
    else {
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)

----- next file -----

@@ -92,7 +92,7 @@ namespace functions {
        for (Nd4jLong i = 0; i < length; i++) {
-           auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, length, canCast);
+           auto xOffset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCast);
            SummaryStatsData<X> curr;
            curr.initWithValue(x[xOffset]);
@@ -175,7 +175,7 @@ namespace functions {
        }
        else {
            for (int i = 1; i < tadLength; i ++) {
-               auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCast);
+               auto xOffset = shape::indexOffset(i, tadShapeShapeInfo, tadShapeShapeInfoCast, canCast);
                SummaryStatsData<X> indexVal2;
                indexVal2.initWithValue(tx[xOffset]);

----- next file -----

@@ -42,7 +42,7 @@ static __global__ void broadcastSimple(
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@@ -64,8 +64,8 @@ static __global__ void broadcastInverseSimple(
namespace functions {
namespace broadcast {
-   static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
-       return shape::getIndexOffset(index, shapeInfo, length);
+   static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+       return shape::getIndexOffset(index, shapeInfo);
    }
    static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
@@ -154,9 +154,9 @@ namespace functions {
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-               auto xOffset = _getIndexOffset(i, xShapeInfo, tadLength);
-               auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
-               auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+               auto xOffset = _getIndexOffset(i, xShapeInfo);
+               auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo);
+               auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
                rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
            }
        }
@@ -170,14 +170,14 @@ namespace functions {
            void *vx, Nd4jLong *xShapeInfo,
            void *vy, Nd4jLong *yShapeInfo,
            void *vz, Nd4jLong *zShapeInfo,
            int *dimension, int dimensionLength,
            Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        if (tadOnlyShapeInfoZ == nullptr) {
            tadOnlyShapeInfoZ = tadOnlyShapeInfo;
            tadOffsetsZ = tadOffsets;
        }
        auto x = reinterpret_cast<X*>(vx);
        auto y = reinterpret_cast<Y*>(vy);
        auto z = reinterpret_cast<Z*>(vz);
@@ -212,16 +212,16 @@ namespace functions {
        if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && xOrder == yOrder && xOrder == zOrder) {
            for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
                rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
        }
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-               auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo, tadLength);
-               auto yOffset = _getIndexOffset(i, yShapeInfo, tadLength);
-               auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+               auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo);
+               auto yOffset = _getIndexOffset(i, yShapeInfo);
+               auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
                rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
            }
        }
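On the CUDA side the same signature trim reaches `_getIndexOffset`, the file-local `__device__ __noinline__` wrapper around `shape::getIndexOffset`; keeping it out-of-line avoids duplicating the offset math at every call site in the broadcast kernels. A hedged CUDA sketch of that wrapper plus the fallback loop pattern above, with the offset math assumed as in the earlier sketch:

    // Hedged CUDA sketch: an out-of-line device helper for index -> offset,
    // used by a block-strided loop like the non-EWS branch above.
    static __device__ __noinline__ long long getIndexOffsetSketch(long long index,
                                                                  const long long* shapeInfo) {
        const int rank = static_cast<int>(shapeInfo[0]);
        long long offset = 0;
        for (int dim = rank - 1; dim >= 0; --dim) {            // C-order coordinates
            offset += (index % shapeInfo[1 + dim]) * shapeInfo[1 + rank + dim];
            index  /= shapeInfo[1 + dim];
        }
        return offset;
    }

    template <typename OpType>
    __device__ void tadLoopSketch(const float* x, float* z,
                                  const long long* xShapeInfo,
                                  const long long* zShapeInfo,
                                  long long tadLength) {
        // every thread strides over the TAD, exactly like the loops above
        for (long long i = threadIdx.x; i < tadLength; i += blockDim.x) {
            auto xOffset = getIndexOffsetSketch(i, xShapeInfo);
            auto zOffset = getIndexOffsetSketch(i, zShapeInfo);
            z[zOffset] = OpType::op(x[xOffset]);
        }
    }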

----- next file -----

@@ -42,7 +42,7 @@ static __global__ void broadcastBoolSimple(
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@@ -145,9 +145,9 @@ namespace functions {
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-               auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
-               auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
-               auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+               auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+               auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
+               auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
                rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
            }
@@ -183,13 +183,13 @@ namespace functions {
        __shared__ int numTads;
        __shared__ Nd4jLong yEWS;
        __shared__ Nd4jLong zEWS;
        if (threadIdx.x == 0) {
            tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
            tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
            numTads = shape::length(xShapeInfo) / tadLength;
            yEWS = shape::elementWiseStride(yShapeInfo);
            zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
        }
        __syncthreads();
@@ -213,9 +213,9 @@ namespace functions {
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-               auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
-               auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
-               auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+               auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
+               auto yOffset = shape::getIndexOffset(i, yShapeInfo);
+               auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
                rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
            }


@@ -42,7 +42,7 @@ static __global__ void broadcastIntSimple(
Nd4jLong *zShapeInfo,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
@@ -139,9 +139,9 @@ namespace functions {
else {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
- auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
- auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
+ auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
}
@@ -177,13 +177,13 @@ namespace functions {
__shared__ int numTads;
__shared__ Nd4jLong yEWS;
__shared__ Nd4jLong zEWS;
if (threadIdx.x == 0) {
tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = shape::length(xShapeInfo) / tadLength;
yEWS = shape::elementWiseStride(yShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}
__syncthreads();
@@ -207,9 +207,9 @@ namespace functions {
else {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
- auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
+ auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
+ auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
}


@@ -246,12 +246,12 @@ namespace functions {
if (dimensionLength > 1 || tadEWS < 1) {
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
auto tadOffsetForBlock = tadOffsets[r];
sPartials[threadIdx.x] = OpType::startingIndexValue(dx);
for(int i = threadIdx.x;i < tadLength; i += blockDim.x) {
- auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+ auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
IndexValue<X> comp {dx[xOffset], i};
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], comp, extraParams);
}
@@ -297,9 +297,9 @@ namespace functions {
reduction = OpType::update(reduction, indexVal, extraParams);
}
} else {
for(Nd4jLong i = tid;i < n; i += blockDim.x * gridDim.x) {
- auto offset = shape::getIndexOffset(i, xShapeInfo, n);
+ auto offset = shape::getIndexOffset(i, xShapeInfo);
IndexValue<X> indexVal = {dx[offset], i};
reduction = OpType::update(reduction, indexVal, extraParams);
}
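The IndexValue pattern in these index-reduce hunks carries a {value, index} pair through OpType::update, so the winning element's position survives the reduction. A minimal host-side model of that accumulation (argmax flavour; the real OpType defines the comparison — the names here are hypothetical stand-ins, a sketch only):

#include <cstdint>
#include <cstdio>

using Nd4jLong = std::int64_t;

// Hypothetical stand-ins for IndexValue<X> and OpType::update.
struct IndexValueSketch { float value; Nd4jLong index; };

IndexValueSketch updateSketch(IndexValueSketch best, IndexValueSketch comp) {
    return comp.value > best.value ? comp : best; // argmax: keep the larger value's index
}

int main() {
    float dx[] = {3.f, 9.f, 1.f, 9.5f, 2.f};
    IndexValueSketch reduction {dx[0], 0};
    for (Nd4jLong i = 1; i < 5; ++i)
        reduction = updateSketch(reduction, IndexValueSketch{dx[i], i});
    printf("max %.1f at index %lld\n", reduction.value, (long long) reduction.index); // 9.5 at 3
    return 0;
}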


@@ -115,7 +115,7 @@ namespace functions {
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams);
else
for (int i = tid; i < len; i += blockDim.x * gridDim.x)
- sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo, len)], extraParams), extraParams);
+ sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams);
__syncthreads();
aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, len), extraParams);


@@ -73,7 +73,7 @@ namespace functions {
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
- z[shape::getIndexOffset(i, zShapeInfo, length)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, length)], scalar, params);
+ z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params);
}
}
}


@@ -72,8 +72,8 @@ namespace functions {
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
- auto xOffset2 = shape::getIndexOffset(i, shapeInfo, length);
- auto zOffset2 = shape::getIndexOffset(i, zShapeInfo, length);
+ auto xOffset2 = shape::getIndexOffset(i, shapeInfo);
+ auto zOffset2 = shape::getIndexOffset(i, zShapeInfo);
result[zOffset2] = OpType::op(dy[xOffset2], params);
}
}


@@ -169,7 +169,7 @@ namespace functions {
template <>
_CUDA_H void ReduceFunction<float>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, float *x, Nd4jLong *xShapeInfo, float *extraParams, float *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, float *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
DISPATCH_SIMPLE(reduceScalarSimple, float, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, nullptr, 1, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_OPS))
nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarFloat(...) failed");
@@ -177,7 +177,7 @@ namespace functions {
template <>
_CUDA_H void ReduceFunction<float16>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, float16 *x, Nd4jLong *xShapeInfo, float16 *extraParams, float16 *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, float16 *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
DISPATCH_SIMPLE(reduceScalarSimple, float16, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, nullptr, 1, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_OPS))
nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarHalf(...) failed");
@@ -185,7 +185,7 @@ namespace functions {
template <>
_CUDA_H void ReduceFunction<double>::execReduceScalar(dim3 launchDims, cudaStream_t *stream, int opNum, double *x, Nd4jLong *xShapeInfo, double *extraParams, double *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, double *reductionBuffer, Nd4jLong *tadOnlyShapeInfo) {
DISPATCH_SIMPLE(reduceScalarSimple, double, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, nullptr, 1, reductionBuffer, tadOnlyShapeInfo), OPS_A(REDUCE_OPS))
nd4j::DebugHelper::checkErrorCode(stream, "execReduceScalarDouble(...) failed");
@@ -294,7 +294,7 @@ namespace functions {
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
shape::ind2subC(tadRank, tadShape, i, tadLength, xCoord);
- auto xOffset = shape::getOffset(tadOffsetForBlock, tadShape, tadStride, xCoord, tadRank);
+ auto xOffset = shape::getOffset(tadOnlyShapeInfo, xCoord);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(dx[xOffset], extraParams), extraParams);
}
@@ -358,7 +358,7 @@ namespace functions {
for (int i = tid; i < n; i += blockDim.x * gridDim.x) {
shape::ind2subC(rank, xShape, i, n, ind2sub);
- auto offset = shape::getOffset(0, xShape, xStride, ind2sub, rank);
+ auto offset = shape::getOffset(xShapeInfo, ind2sub);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(dx[offset], extraParams), extraParams);
}
}
@@ -461,7 +461,7 @@ namespace functions {
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
shape::ind2subC(tadRank, tadShape, i, tadLength, xCoord);
- auto xOffset = shape::getOffset(tadOffsetForBlock, tadShape, tadStride, xCoord, tadRank);
+ auto xOffset = shape::getOffset(tadOnlyShapeInfo, xCoord);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(dx[xOffset], extraParams), extraParams);
}
@@ -526,7 +526,7 @@ namespace functions {
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
shape::ind2subC(tadRank, tadShape, i, tadLength, xCoord);
- auto xOffset = shape::getOffset(tadOffsetForBlock, tadShape, tadStride, xCoord, tadRank);
+ auto xOffset = shape::getOffset(tadOnlyShapeInfo, xCoord);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(dx[xOffset], extraParams), extraParams);
}
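These reduce hunks show the second recurring signature change: the five-argument shape::getOffset(baseOffset, shape, stride, coords, rank) collapsed to getOffset(shapeInfo, coords), since rank, shape, and strides are all readable from the shapeInfo buffer itself. A sketch under the same assumed [rank, shape..., stride...] layout and Nd4jLong alias as the earlier sketch (an illustration, not the library code):

Nd4jLong getOffsetSketch(const Nd4jLong* shapeInfo, const Nd4jLong* coords) {
    const Nd4jLong rank = shapeInfo[0];
    const Nd4jLong* stride = shapeInfo + 1 + rank;
    Nd4jLong offset = 0;
    for (Nd4jLong d = 0; d < rank; ++d)
        offset += coords[d] * stride[d]; // plain coordinate/stride dot product
    return offset;
}

// e.g. shapeInfo {2, 3, 4, 4, 1} with coords {1, 3} gives 1*4 + 3*1 = 7.

One visible difference in the TAD hunks: the old calls folded tadOffsetForBlock in as the base offset, while the new two-argument form has no base-offset parameter at all.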


@@ -88,8 +88,8 @@ static inline __device__ void transformCuda(T scalar, T *dy, int *shapeInfo, T *
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
shape::ind2sub(xRank, xShape, i, length, xIdx);
- int xOffset2 = shape::getOffset(0, xShape, xStride, xIdx, xRank);
+ int xOffset2 = shape::getOffset(shapeInfo, xIdx);
- int resultOffset = shape::getOffset(0, zShape, zStride, xIdx, zRank);
+ int resultOffset = shape::getOffset(resultShapeInfo, xIdx);
result[resultOffset] = OpType::op(dy[xOffset2],scalar, params);
}
}


@@ -111,7 +111,7 @@ __device__ void transformSimpleGeneric(
manager->init(sizeof(UnifiedSharedMemory), 0, sizeof(functions::transform::Transform<T>), sizeof(shape::TAD), xRank);
}
__syncthreads();
functions::transform::Transform<T>::template transformCuda<OpClass>(
dy,
xShapeInfo,
@@ -161,7 +161,7 @@ namespace functions {
template <>
_CUDA_H void Transform<float>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, float *x, Nd4jLong *xShape, int xRank, float *extraParams, float *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
DISPATCH_SIMPLE(transformShaped, float, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
@@ -170,16 +170,16 @@ namespace functions {
template <>
_CUDA_H void Transform<float16>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, float16 *x, Nd4jLong *xShape, int xRank, float16 *extraParams, float16 *z, Nd4jLong *zShape, int zRank, int *allocationPointer, float16 *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
DISPATCH_SIMPLE(transformShaped, float16, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
if (nd4j::Environment::getInstance()->isDebug())
checkCudaErrors(cudaStreamSynchronize(*stream));
}
template <>
_CUDA_H void Transform<double>::executeTransformShaped(dim3 launchDims, cudaStream_t *stream, int opNum, double *x, Nd4jLong *xShape, int xRank, double *extraParams, double *z, Nd4jLong *zShape, int zRank, int *allocationPointer, double *reductionPointer, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
DISPATCH_SIMPLE(transformShaped, double, PARAMS(x, xShape, xRank, extraParams, z, zShape, zRank, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets), OPS_A(TRANSFORM_OPS))
DEBUG_KERNEL(stream, opNum);
@@ -226,13 +226,13 @@ namespace functions {
}
else {
Nd4jLong xCoord[MAX_RANK];
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
shape::ind2sub(xRank,shape::shapeOf(shapeInfo),i, length, xCoord);
- auto xOffset2 = shape::getOffset(0, xShape, xStride, xCoord, xRank);
- auto resultOffset2 = shape::getOffset(0,xShape,shape::stride(resultShapeInfo),xCoord,xRank);
+ auto xOffset2 = shape::getOffset(shapeInfo, xCoord);
+ auto resultOffset2 = shape::getOffset(resultShapeInfo, xCoord);
result[resultOffset2] = OpType::op(dy[xOffset2], params);
}
}
@@ -249,7 +249,7 @@ namespace functions {
T *result,
Nd4jLong resultStride,
int *allocationPointer, T *reductionPointer, UnifiedSharedMemory *manager) {
int totalThreads = gridDim.x * blockDim.x;
Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x;

@@ -28,11 +28,11 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void *vextraParams) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<Y*>(vy);
auto z = reinterpret_cast<Z*>(vz);
@@ -67,17 +67,17 @@ __global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
}
else if (vx == vz) {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
else {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
- auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
+ auto zOffset = shape::getIndexOffset(i, zShapeInfo);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
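The pairwise kernel's branch structure above is worth spelling out: when vx == vz the output aliases the input, so the x offset is reused for z (one fewer index computation per element); only the fully general branch resolves all three offsets. A host-side model of that dispatch, with the grid-stride loop flattened to a plain loop and reusing getIndexOffsetSketch from the earlier sketch (same layout assumptions):

template <typename T, typename Op>
void pairwiseSketch(const T* x, const Nd4jLong* xShapeInfo,
                    const T* y, const Nd4jLong* yShapeInfo,
                    T* z, const Nd4jLong* zShapeInfo,
                    Nd4jLong len, Op op) {
    if ((const void*) z == (const void*) x) {
        for (Nd4jLong i = 0; i < len; ++i) { // in-place: z shares x's buffer and layout
            auto xOffset = getIndexOffsetSketch(i, xShapeInfo);
            auto yOffset = getIndexOffsetSketch(i, yShapeInfo);
            z[xOffset] = op(x[xOffset], y[yOffset]);
        }
    } else {
        for (Nd4jLong i = 0; i < len; ++i) { // general: each array resolves its own offset
            auto xOffset = getIndexOffsetSketch(i, xShapeInfo);
            auto yOffset = getIndexOffsetSketch(i, yShapeInfo);
            auto zOffset = getIndexOffsetSketch(i, zShapeInfo);
            z[zOffset] = op(x[xOffset], y[yOffset]);
        }
    }
}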


@@ -67,17 +67,17 @@ __global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
}
else if (vx == vz) {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
else {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
- auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
+ auto zOffset = shape::getIndexOffset(i, zShapeInfo);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
@@ -105,7 +105,7 @@ void _CUDA_H PairWiseBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cu
template<typename X, typename Y>
void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
auto xType = nd4j::DataTypeUtils::fromT<X>();
auto yType = nd4j::DataTypeUtils::fromT<Y>();
DISPATCH_BY_OPNUM_TT(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_BOOL_OPS);
}
@@ -166,7 +166,7 @@ void PairWiseBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_
}
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT PairWiseBoolTransform, , LIBND4J_TYPES, BOOL_TYPES);
}
}


@@ -67,17 +67,17 @@ __global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
}
else if (vx == vz) {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
else {
for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
- auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
+ auto zOffset = shape::getIndexOffset(i, zShapeInfo);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
@@ -165,7 +165,7 @@ void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}


@@ -116,7 +116,7 @@ namespace functions {
auto y = reinterpret_cast<T*>(vy);
auto z = reinterpret_cast<T*>(vz);
auto extraArguments = reinterpret_cast<T*>(vextraArguments);
if (OpClass::requiresSpecial) {
OpClass::specialOpCuda(state, x, xShapeBuffer, y, yShapeBuffer, z, zShapeBuffer, extraArguments);
return;
@@ -166,10 +166,10 @@ namespace functions {
}
} else {
for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) {
- auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer, length);
- auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer, length);
- auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer, length);
+ auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer);
+ auto yOffset2 = shape::getIndexOffset(i, yShapeBuffer);
+ auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer);
z[zOffset2] = OpClass::op(x[xOffset2], y[yOffset2], i, length, buffer, extraArguments);
}
@@ -224,11 +224,11 @@ namespace functions {
z[e * zEWS] = OpClass::op(x[e * xEWS], e, length, buffer, extraArguments);
}
} else {
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < length; i += blockDim.x * gridDim.x) {
- auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer, length);
- auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer, length);
+ auto xOffset2 = shape::getIndexOffset(i, xShapeBuffer);
+ auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer);
z[zOffset2] = OpClass::op(x[xOffset2], i, length, buffer, extraArguments);
}
@@ -274,9 +274,9 @@ namespace functions {
z[i * ews] = OpClass::op(i, length, buffer, extraArguments);
}
} else {
for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) {
- auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer, length);
+ auto zOffset2 = shape::getIndexOffset(i, zShapeBuffer);
z[zOffset2] = OpClass::op(i, length, buffer, extraArguments);
}
}
@@ -296,7 +296,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<float16>::executeCudaSingle(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto z = reinterpret_cast<float16*>(vz);
auto extraArguments = reinterpret_cast<float16*>(vextraArguments);
@@ -320,7 +320,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<double>::executeCudaSingle(dim3& launchDims, cudaStream_t *stream, int opNum, Nd4jPointer stateHost, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto z = reinterpret_cast<double*>(vz);
auto extraArguments = reinterpret_cast<double*>(vextraArguments);
@@ -332,7 +332,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<float>::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vx, Nd4jLong *xShapeBuffer, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto x = reinterpret_cast<float*>(vx);
auto z = reinterpret_cast<float*>(vz);
auto extraArguments = reinterpret_cast<float*>(vextraArguments);
@@ -346,7 +346,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<float16>::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vx, Nd4jLong *xShapeBuffer, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto x = reinterpret_cast<float16*>(vx);
auto z = reinterpret_cast<float16*>(vz);
auto extraArguments = reinterpret_cast<float16*>(vextraArguments);
@@ -372,7 +372,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<double>::executeCudaDouble(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vx, Nd4jLong *xShapeBuffer, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto x = reinterpret_cast<double*>(vx);
auto z = reinterpret_cast<double*>(vz);
auto extraArguments = reinterpret_cast<double*>(vextraArguments);
@@ -385,7 +385,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<float>::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vx, Nd4jLong *xShapeBuffer, void *vy, Nd4jLong *yShapeBuffer, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto x = reinterpret_cast<float*>(vx);
auto y = reinterpret_cast<float*>(vy);
@@ -400,7 +400,7 @@ namespace functions {
template <>
_CUDA_H void RandomFunction<float16>::executeCudaTriple(dim3& launchDims, cudaStream_t* stream, int opNum, Nd4jPointer stateHost, void *vx, Nd4jLong *xShapeBuffer, void *vy, Nd4jLong *yShapeBuffer, void *vz, Nd4jLong *zShapeBuffer, void *vextraArguments) {
auto x = reinterpret_cast<float16*>(vx);
auto y = reinterpret_cast<float16*>(vy);
auto z = reinterpret_cast<float16*>(vz);


@@ -129,7 +129,7 @@ __device__ void ReduceBoolFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
- auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+ auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
}
__syncthreads();
@@ -140,7 +140,7 @@ __device__ void ReduceBoolFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS
__syncthreads();
if (threadIdx.x == 0)
- z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo, numTads)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
+ z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
}
}
@@ -180,7 +180,7 @@ __device__ void ReduceBoolFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSha
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams);
else
for (int i = tid; i < len; i += blockDim.x * gridDim.x)
- sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo, len)], extraParams), extraParams);
+ sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams);
__syncthreads();
aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, len), extraParams);


@@ -129,7 +129,7 @@ __device__ void ReduceFloatFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *x
sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
- auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+ auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
}
__syncthreads();
@@ -139,7 +139,7 @@ __device__ void ReduceFloatFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *x
__syncthreads();
if (threadIdx.x == 0)
- z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo, numTads)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
+ z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
}
}
@@ -179,7 +179,7 @@ __device__ void ReduceFloatFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSh
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams);
else
for (int i = tid; i < len; i += blockDim.x * gridDim.x)
- sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo, len)], extraParams), extraParams);
+ sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams);
__syncthreads();
aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, len), extraParams);
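All of these reduce kernels follow the same two-phase shape: each thread folds its strided slice of the input into sPartials[threadIdx.x], then aggregatePartials collapses the shared-memory array pairwise until sPartials[0] holds the block's result. A host-side model of that second phase (sum flavour, power-of-two item count assumed; on the device a __syncthreads() separates the halving steps — a sketch of the pattern, not the library routine):

void aggregatePartialsSketch(float* sPartials, int numItems) {
    // Tree fold: halve the active range until one partial remains.
    for (int stride = numItems / 2; stride > 0; stride /= 2) {
        for (int tid = 0; tid < stride; ++tid)         // every "thread" below stride
            sPartials[tid] += sPartials[tid + stride]; // OpType::update — here, a sum
        // __syncthreads() would sit here in the kernel
    }
}

// e.g. {1,2,3,4,5,6,7,8} collapses to 36 in sPartials[0] after three passes.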


@@ -150,7 +150,7 @@ __device__ void ReduceLongFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS
sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
- auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+ auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
}
__syncthreads();
@@ -160,7 +160,7 @@ __device__ void ReduceLongFunction<X,Z>::transformCudaXD( void *vx, Nd4jLong *xS
__syncthreads();
if (threadIdx.x == 0)
- z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo, numTads)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
+ z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
}
}
@@ -200,7 +200,7 @@ __device__ void ReduceLongFunction<X,Z>::execScalarCuda(void *vx, Nd4jLong *xSha
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams);
else
for (int i = tid; i < len; i += blockDim.x * gridDim.x)
- sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo, len)], extraParams), extraParams);
+ sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams);
__syncthreads();
aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, len), extraParams);


@@ -139,7 +139,7 @@ __device__ void ReduceSameFunction<X>::transformCudaXD( void *vx, Nd4jLong *xSha
sPartials[threadIdx.x] = OpType::startingValue(x + tadOffsetForBlock);
for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
- auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+ auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[xOffset], extraParams), extraParams);
}
__syncthreads();
@@ -149,7 +149,7 @@ __device__ void ReduceSameFunction<X>::transformCudaXD( void *vx, Nd4jLong *xSha
__syncthreads();
if (threadIdx.x == 0)
- z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo, numTads)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
+ z[isPlainOutput ? r : shape::getIndexOffset(r, zShapeInfo)] = OpType::postProcess(sPartials[threadIdx.x], tadLength, extraParams);
}
}
@@ -197,7 +197,7 @@ __device__ void ReduceSameFunction<X>::execScalarCuda(void *vx, Nd4jLong *xShape
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[i * xEws], extraParams), extraParams);
else
for (int i = tid; i < len; i += blockDim.x * gridDim.x)
- sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo, len)], extraParams), extraParams);
+ sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], extraParams), extraParams);
__syncthreads();
aggregatePartials<OpType>(sPartials, threadIdx.x, nd4j::math::nd4j_min<int>(blockDim.x, len), extraParams);


@@ -161,8 +161,8 @@ __device__ void Reduce3<X,Z>::execScalarCuda( void *vx, Nd4jLong *xShapeInfo,
sPartials[threadIdx.x] = OpType::startingValue(x);
auto threadCount = gridDim.x * blockDim.x;
for(Nd4jLong i = tid; i < length; i += threadCount) {
- auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
- auto yOffset = shape::getIndexOffset(i, yShapeInfo, length);
+ auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+ auto yOffset = shape::getIndexOffset(i, yShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset], y[yOffset], extraZ), extraZ);
}
}
@@ -290,7 +290,7 @@ __device__ void Reduce3<X,Z>::transformAll( void *vx, Nd4jLong *xShapeInfo,
X *x = dx + xOffsets[r];
if (threadIdx.x < xTadLength && threadIdx.x < maxBlock) {
- auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo, shape::length(xTadShapeInfo));
+ auto x0 = shape::getIndexOffset(threadIdx.x, xTadShapeInfo);
tempX[threadIdx.x] = x[x0];
}
__syncthreads();
@@ -311,12 +311,12 @@ __device__ void Reduce3<X,Z>::transformAll( void *vx, Nd4jLong *xShapeInfo,
// we reset tempX IF we have >1 tiles
if (t >= 1 || (limit > 1 && g > 0))
if (threadIdx.x + (t * maxBlock) < xTadLength) {
- auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), xTadShapeInfo, xTadLength);
+ auto x0 = shape::getIndexOffset(threadIdx.x + (t * maxBlock), xTadShapeInfo);
tempX[threadIdx.x] = x[x0];
}
for (int f = threadIdx.x + (t * maxBlock); f < xTadLength && f < threadIdx.x + ((t + 1) * maxBlock); f += blockDim.x * gridDim.x) {
- auto y0 = shape::getIndexOffset(f, yTadShapeInfo, yTadLength);
+ auto y0 = shape::getIndexOffset(f, yTadShapeInfo);
sPartials[threadIdx.x] = OpType::update(sPartials[threadIdx.x], OpType::opAtomic(tempX[threadIdx.x], y[y0], extraZ), extraZ);
}
@@ -433,8 +433,8 @@ __device__ void Reduce3<X,Z>::transform(void *vx, Nd4jLong *xShapeInfo,
for (int j = threadIdx.x; j < tadLen; j += blockDim.x) {
- Nd4jLong xOffset2 = xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo, tadLen);
- Nd4jLong yOffset2 = yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo, tadLen);
+ Nd4jLong xOffset2 = xOffset + shape::getIndexOffset(j, tadOnlyShapeInfo);
+ Nd4jLong yOffset2 = yOffset + shape::getIndexOffset(j, yTadOnlyShapeInfo);
sPartials[threadIdx.x] = j < blockDim.x ? OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ) : OpType::update(sPartials[threadIdx.x], OpType::opAtomic(x[xOffset2], y[yOffset2], extraZ), extraZ);
}
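Reduce3 differs from the plain reduces in that it folds two arrays into one scalar (dot product, distances, and so on): each element pair goes through opAtomic before feeding update. A host-side model of execScalarCuda's accumulation loop, dot-product flavour, reusing getIndexOffsetSketch from earlier (a sketch under those layout assumptions, not the kernel):

float reduce3ScalarSketch(const float* x, const Nd4jLong* xShapeInfo,
                          const float* y, const Nd4jLong* yShapeInfo,
                          Nd4jLong length) {
    float acc = 0.f; // OpType::startingValue
    for (Nd4jLong i = 0; i < length; ++i) {
        auto xOffset = getIndexOffsetSketch(i, xShapeInfo);
        auto yOffset = getIndexOffsetSketch(i, yShapeInfo);
        acc += x[xOffset] * y[yOffset]; // opAtomic + update, dot-product flavour
    }
    return acc;
}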


@@ -33,7 +33,7 @@ using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
__global__ static void scalarSimpleShaped(void* vx, void *vscalar, Nd4jLong *xShapeInfo, void *vparams, void *vz, Nd4jLong *zShapeInfo, int *allocationBuffer) {
auto scalar = reinterpret_cast<Y*>(vscalar)[0];
auto x = reinterpret_cast<X*>(vx);
auto params = reinterpret_cast<Z*>(vparams);
@@ -61,10 +61,10 @@ __global__ static void scalarSimpleShaped(void* vx, void *vscalar, Nd4jLong *xSh
}
} else {
for (Nd4jLong i = tid; i < length; i += totalThreads) {
- z[shape::getIndexOffset(i, zShapeInfo, length)] = OpType::op(x[shape::getIndexOffset(i, xShapeInfo, length)], scalar, params);
+ z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(x[shape::getIndexOffset(i, xShapeInfo)], scalar, params);
}
}
}
////////////////////////////////////////////////////////////////////////////////
@@ -76,7 +76,7 @@ __global__ static void scalarAlongDimension(void *vx, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto x = reinterpret_cast<X*>(vx);
auto extraParams = reinterpret_cast<Z*>(vextraParams);
auto z = reinterpret_cast<Z*>(vz);
@@ -114,7 +114,7 @@ __global__ static void scalarAlongDimension(void *vx, Nd4jLong *xShapeInfo,
auto s = scalars[r];
for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
- oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams);
+ oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams);
}
}
}
@@ -127,7 +127,7 @@ namespace scalar {
template<typename X, typename Y, typename Z>
template<typename OpType>
void _CUDA_H ScalarTransform<X,Y,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hxShapeInfo, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hzShapeInfo, void* vscalar, void *vextraParams, int *allocPointer){
auto xEws = shape::elementWiseStride(hxShapeInfo);
auto xOrder = shape::order(hxShapeInfo);
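The scalarAlongDimension hunks (here and in the bool/int variants below) all share one traversal: TAD r is located through tadOffsets[r] and combined elementwise with its own scalar scalars[r]; when tadShapeInfoZ is nullptr the kernels reuse the input TAD geometry for the output, as the hunk further down shows. A host-side model of that traversal under the same assumption (one CUDA block per TAD, the block's threads flattened to the inner loop, getIndexOffsetSketch as before — a sketch only):

template <typename T, typename Op>
void scalarAlongDimensionSketch(const T* x, T* z,
                                const Nd4jLong* tadShapeInfo, const Nd4jLong* tadOffsets,
                                const T* scalars, Nd4jLong numTads, Nd4jLong tadLength,
                                Op op) {
    for (Nd4jLong r = 0; r < numTads; ++r) {       // one block per TAD
        const T* oX = x + tadOffsets[r];           // this TAD's slice of x
        T* oZ = z + tadOffsets[r];                 // z assumed to share TAD geometry
        for (Nd4jLong f = 0; f < tadLength; ++f) { // the block's threads
            auto off = getIndexOffsetSketch(f, tadShapeInfo);
            oZ[off] = op(oX[off], scalars[r]);
        }
    }
}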


@ -36,7 +36,7 @@ __global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::scalar::ScalarBoolTransform<X,Z>::template transformCuda<OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); functions::scalar::ScalarBoolTransform<X,Z>::template transformCuda<OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
} }
@ -60,10 +60,10 @@ namespace scalar {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
__device__ void ScalarBoolTransform<X, Z>::transformCuda(void* vscalar, __device__ void ScalarBoolTransform<X, Z>::transformCuda(void* vscalar,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vparams, void *vparams,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
int *allocationBuffer) { int *allocationBuffer) {
auto scalar = reinterpret_cast<X*>(vscalar)[0]; auto scalar = reinterpret_cast<X*>(vscalar)[0];
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X*>(vy);
@ -73,8 +73,8 @@ __device__ void ScalarBoolTransform<X, Z>::transformCuda(void* vscalar,
auto yRank = shape::rank(yShapeInfo); auto yRank = shape::rank(yShapeInfo);
auto yEWS = shape::elementWiseStride(yShapeInfo); auto yEWS = shape::elementWiseStride(yShapeInfo);
auto yShape = shape::shapeOf(yShapeInfo); auto yShape = shape::shapeOf(yShapeInfo);
auto yStride = shape::stride(yShapeInfo); auto yStride = shape::stride(yShapeInfo);
auto zRank = shape::rank(zShapeInfo); auto zRank = shape::rank(zShapeInfo);
auto zEWS = shape::elementWiseStride(zShapeInfo); auto zEWS = shape::elementWiseStride(zShapeInfo);
auto zShape = shape::shapeOf(zShapeInfo); auto zShape = shape::shapeOf(zShapeInfo);
@ -89,22 +89,22 @@ __device__ void ScalarBoolTransform<X, Z>::transformCuda(void* vscalar,
__syncthreads(); __syncthreads();
if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) {
transformCuda<OpType>(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); transformCuda<OpType>(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer);
} }
else { else {
for (Nd4jLong i = tid; i < len; i+= totalThreads) for (Nd4jLong i = tid; i < len; i+= totalThreads)
z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params); z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params);
} }
} }
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template<typename OpType> template<typename OpType>
__device__ void ScalarBoolTransform<X, Z>::transformCuda(Nd4jLong len, __device__ void ScalarBoolTransform<X, Z>::transformCuda(Nd4jLong len,
void* vx, void* vx,
                        void *vy, Nd4jLong yEWS,
                        void *vparams,
                        void *vz, Nd4jLong zEWS,
                        int *allocationBuffer) {
    auto x = reinterpret_cast<X*>(vx)[0];
@@ -130,18 +130,18 @@ __device__ void ScalarBoolTransform<X, Z>::transformCuda(Nd4jLong len,
////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template<typename OpType>
__device__ void ScalarBoolTransform<X, Z>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
                        void *vextraParams,
                        void *vz, Nd4jLong *zShapeInfo,
                        void *vscalars,
                        int *dimension, int dimensionLength,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    auto x = reinterpret_cast<X*>(vx);
    auto scalars = reinterpret_cast<X*>(vscalars);
    auto z = reinterpret_cast<Z*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);
    if (tadShapeInfoZ == nullptr) {
        tadShapeInfoZ = tadShapeInfo;
        tadOffsetsZ = tadOffsets;
@@ -174,7 +174,7 @@ __device__ void ScalarBoolTransform<X, Z>::transformCuda(void *vx, Nd4jLong *xS
        auto s = scalars[r];
        for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
-           oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams);
+           oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams);
    }
}
}
@@ -184,12 +184,12 @@ __device__ void ScalarBoolTransform<X, Z>::transformCuda(void *vx, Nd4jLong *xS
template<typename X, typename Z>
template <typename OpType>
_CUDA_H void ScalarBoolTransform<X, Z>::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream,
                        void *x, Nd4jLong *xShapeInfo,
                        void *z, Nd4jLong *zShapeInfo,
                        void *scalars,
                        void *extraParams,
                        int *dimension, int dimensionLength,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    scalarAlongDimension<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
@@ -200,11 +200,11 @@ _CUDA_H void ScalarBoolTransform<X, Z>::intermediateAlongDimension(dim3& launchD
template<typename X, typename Z>
template<typename OpType>
void _CUDA_H ScalarBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                        void *vx, Nd4jLong *xShapeInfo,
                        void *vz, Nd4jLong *zShapeInfo,
                        void* vscalar,
                        void *vextraParams, int *allocPointer){
    scalarSimpleShaped<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer);
    nd4j::DebugHelper::checkErrorCode(stream, "scalarSimpleShaped(...) failed");
}
@@ -212,10 +212,10 @@ void _CUDA_H ScalarBoolTransform<X,Z>::intermediateShaped(dim3& launchDims, cuda
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
void ScalarBoolTransform<X,Y>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream,
                        int opNum,
                        void *vx, Nd4jLong *xShapeInfo,
                        void *vz, Nd4jLong *zShapeInfo,
                        void* vscalar,
                        void *vextraParams) {
    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
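A note on the pattern repeated throughout these hunks: shape::getIndexOffset dropped its trailing length argument, since the array length is never needed to map a linear index to a buffer offset. A minimal sketch of what such a mapping does, under an illustrative descriptor (the names and layout here are assumptions, not the library's actual shapeInfo format):

    #include <cstdint>

    // Illustrative stand-in for a shapeInfo descriptor: rank plus per-dimension
    // extents and strides (in elements). The real libnd4j layout differs.
    struct ShapeInfoView {
        int rank;
        const int64_t* shape;
        const int64_t* stride;
    };

    // Conceptual equivalent of the two-argument getIndexOffset(index, shapeInfo):
    // de-linearize the index in c (row-major) order and dot the coordinates with
    // the strides. Note that the total length is never consulted.
    static int64_t getIndexOffsetSketch(int64_t index, const ShapeInfoView& info) {
        int64_t offset = 0;
        for (int i = info.rank - 1; i >= 0; --i) {
            offset += (index % info.shape[i]) * info.stride[i];
            index /= info.shape[i];
        }
        return offset;
    }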
View File
@@ -36,7 +36,7 @@ __global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo,
                        int *dimension, int dimensionLength,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    functions::scalar::ScalarIntTransform<X>::template transformCuda<OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
}
@@ -60,10 +60,10 @@ namespace scalar {
////////////////////////////////////////////////////////////////////////
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
                        void *vy, Nd4jLong *yShapeInfo,
                        void *vparams,
                        void *vz, Nd4jLong *zShapeInfo,
                        int *allocationBuffer) {
    auto scalar = reinterpret_cast<X*>(vscalar)[0];
    auto y = reinterpret_cast<X*>(vy);
@@ -73,8 +73,8 @@ __device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
    auto yRank = shape::rank(yShapeInfo);
    auto yEWS = shape::elementWiseStride(yShapeInfo);
    auto yShape = shape::shapeOf(yShapeInfo);
    auto yStride = shape::stride(yShapeInfo);
    auto zRank = shape::rank(zShapeInfo);
    auto zEWS = shape::elementWiseStride(zShapeInfo);
    auto zShape = shape::shapeOf(zShapeInfo);
@@ -89,11 +89,11 @@ __device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
    __syncthreads();
    if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) {
        transformCuda<OpType>(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer);
    }
    else {
        for (Nd4jLong i = tid; i < len; i+= totalThreads)
-           z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params);
+           z[shape::getIndexOffset(i, zShapeInfo)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo)], scalar, params);
    }
}
@@ -101,10 +101,10 @@ __device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(Nd4jLong len,
                        void* vx,
                        void *vy, Nd4jLong yEWS,
                        void *vparams,
                        void *vz, Nd4jLong zEWS,
                        int *allocationBuffer) {
    auto x = reinterpret_cast<X*>(vx)[0];
@@ -131,17 +131,17 @@ __device__ void ScalarIntTransform<X>::transformCuda(Nd4jLong len,
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
                        void *vextraParams,
                        void *vz, Nd4jLong *zShapeInfo,
                        void *vscalars,
                        int *dimension, int dimensionLength,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    auto x = reinterpret_cast<X*>(vx);
    auto scalars = reinterpret_cast<X*>(vscalars);
    auto z = reinterpret_cast<X*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);
    if (tadShapeInfoZ == nullptr) {
        tadShapeInfoZ = tadShapeInfo;
        tadOffsetsZ = tadOffsets;
@@ -174,7 +174,7 @@ __device__ void ScalarIntTransform<X>::transformCuda(void *vx, Nd4jLong *xShape
        auto s = scalars[r];
        for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
-           oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams);
+           oZ[shape::getIndexOffset(f, tadShapeInfoZ)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo)], s, extraParams);
    }
}
}
@@ -184,12 +184,12 @@ __device__ void ScalarIntTransform<X>::transformCuda(void *vx, Nd4jLong *xShape
template<typename X>
template <typename OpType>
_CUDA_H void ScalarIntTransform<X>::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream,
                        void *x, Nd4jLong *xShapeInfo,
                        void *z, Nd4jLong *zShapeInfo,
                        void *scalars,
                        void *extraParams,
                        int *dimension, int dimensionLength,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    scalarAlongDimension<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
@@ -199,21 +199,21 @@ _CUDA_H void ScalarIntTransform<X>::intermediateAlongDimension(dim3& launchDims,
template<typename X>
template<typename OpType>
void _CUDA_H ScalarIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                        void *vx, Nd4jLong *xShapeInfo,
                        void *vz, Nd4jLong *zShapeInfo,
                        void* vscalar,
                        void *vextraParams, int *allocPointer){
    scalarSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer);
}
////////////////////////////////////////////////////////////////////////
template<typename X>
void ScalarIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream,
                        int opNum,
                        void *vx, Nd4jLong *xShapeInfo,
                        void *vz, Nd4jLong *zShapeInfo,
                        void* vscalar,
                        void *vextraParams) {
    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
View File
@@ -80,8 +80,8 @@ __global__ void bitonicArbitraryStepKernelKey(void *vx, Nd4jLong *xShapeInfo, vo
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i+j;
            if (it < length && ij < length ) {
-               int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
-               int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
+               int posIT = shape::getIndexOffset(it, xShapeInfo);
+               int posIJ = shape::getIndexOffset(ij, xShapeInfo);
                X v0 = x[posIJ];
                X v1 = x[posIT];
@@ -160,8 +160,8 @@ __global__ void execBitonicArbitraryStepKernel(void *vx, Nd4jLong *xShapeInfo, i
            int it = (reverse) ? i + j + half : i + window - j - 1;
            int ij = i+j;
            if (it < length && ij < length ) {
-               int posIT = shape::getIndexOffset(it, xShapeInfo, xLength);
-               int posIJ = shape::getIndexOffset(ij, xShapeInfo, xLength);
+               int posIT = shape::getIndexOffset(it, xShapeInfo);
+               int posIJ = shape::getIndexOffset(ij, xShapeInfo);
                shmem[threadIdx.x] = x[posIJ];
                shmem[threadIdx.x + blockDim.x] = x[posIT];
View File
@@ -46,8 +46,8 @@ __global__ void bitonicSortStepKernelKey(void *vx, Nd4jLong *xShapeInfo, void *v
        /* The threads with the lowest ids sort the array. */
        if ((ixj)>i) {
-           int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
-           int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
+           int posI = shape::getIndexOffset(i, xShapeInfo);
+           int posIXJ = shape::getIndexOffset(ixj, xShapeInfo);
            if ((i&k)==0) {
                /* Sort ascending */
@@ -100,8 +100,8 @@ __global__ void bitonicSortStepKernel(void *vx, Nd4jLong *xShapeInfo, int j, int
        /* The threads with the lowest ids sort the array. */
        if ((ixj)>i) {
-           int posI = shape::getIndexOffset(i, xShapeInfo, xLength);
-           int posIXJ = shape::getIndexOffset(ixj, xShapeInfo, xLength);
+           int posI = shape::getIndexOffset(i, xShapeInfo);
+           int posIXJ = shape::getIndexOffset(ixj, xShapeInfo);
            if ((i&k)==0) {
                /* Sort ascending */
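For orientation, the bitonic kernels above pick the partner element as i XOR j and read the sort direction from bit k of the index. A CPU-side sketch of one compare-exchange step under that rule (illustrative only; the real kernels distribute i across threads):

    #include <utility>

    // One bitonic compare-exchange pass: partner = i ^ j, and bit k of i
    // decides whether the pair is sorted ascending or descending.
    template <typename T>
    static void bitonicStepSketch(T* x, int length, int j, int k) {
        for (int i = 0; i < length; ++i) {
            const int ixj = i ^ j;                  // partner element
            if (ixj > i && ixj < length) {
                const bool ascending = (i & k) == 0;
                if (ascending ? (x[i] > x[ixj]) : (x[i] < x[ixj]))
                    std::swap(x[i], x[ixj]);
            }
        }
    }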
View File
@@ -139,19 +139,19 @@ namespace nd4j {
            Nd4jLong sub[MAX_RANK];
-           shape::index2coords(shape::rank(zTadShape),shape::shapeOf(zTadShape), arrOffset, sub, shape::order(zTadShape));
-           Nd4jLong baseOffset = shape::getOffset(0,shape::shapeOf(zTadShape),shape::stride(zTadShape), sub, shape::rank(zTadShape));
+           shape::index2coords(arrOffset, zTadShape, sub);
+           Nd4jLong baseOffset = shape::getOffset(zTadShape, sub);
            resultTAD += baseOffset;
            auto yRank = shape::rank(currentTad);
            auto tadRank = shape::rank(zTadShape);
-           shape::index2coords(yRank, shape::shapeOf(currentTad), 0, sub);
-           auto yOffset = shape::getOffset(0, shape::shapeOf(currentTad), shape::stride(currentTad), sub, yRank);
-           resultOffset = shape::getOffset(0, shape::shapeOf(zTadShape), shape::stride(zTadShape), sub, tadRank);
+           shape::index2coords(0, currentTad, sub);
+           auto yOffset = shape::getOffset(currentTad, sub);
+           resultOffset = shape::getOffset(zTadShape, sub);
            resultTAD[resultOffset] = dataTAD[yOffset];
        }
@@ -168,8 +168,8 @@ namespace nd4j {
            Nd4jLong sub[MAX_RANK];
-           shape::index2coords(shape::rank(zTadShape),shape::shapeOf(zTadShape), arrOffset, sub);
-           Nd4jLong baseOffset = shape::getOffset(0,shape::shapeOf(zTadShape),shape::stride(zTadShape), sub, shape::rank(zTadShape));
+           shape::index2coords(arrOffset, zTadShape, sub);
+           Nd4jLong baseOffset = shape::getOffset(zTadShape, sub);
            resultTAD += baseOffset;
@@ -203,8 +203,8 @@ namespace nd4j {
            auto yRank = shape::rank(currentTad);
            for (int i = threadIdx.x; i < yLength; i+= blockDim.x) {
-               shape::index2coords(yRank, shape::shapeOf(currentTad), i, yIdx);
-               auto yOffset = shape::getOffset(0, shape::shapeOf(currentTad), shape::stride(currentTad), yIdx, yRank);
+               shape::index2coords(i, currentTad, yIdx);
+               auto yOffset = shape::getOffset(currentTad, yIdx);
                resultTAD[baseIdx + i * tadEWS] = dataTAD[yOffset];
            }
@@ -220,11 +220,11 @@ namespace nd4j {
            auto tadRank = shape::rank(zTadShape);
            for (int i = threadIdx.x; i < yLength; i+= blockDim.x) {
-               shape::index2coords(yRank, shape::shapeOf(currentTad), i, yIdx);
-               shape::index2coords(tadRank, shape::shapeOf(zTadShape), i, zIdx);
-               auto yOffset = shape::getOffset(0, shape::shapeOf(currentTad), shape::stride(currentTad), yIdx, yRank);
-               auto resultOffset = shape::getOffset(0, shape::shapeOf(zTadShape), shape::stride(zTadShape), zIdx, tadRank);
+               shape::index2coords(i, currentTad, yIdx);
+               shape::index2coords(i, zTadShape, zIdx);
+               auto yOffset = shape::getOffset(currentTad, yIdx);
+               auto resultOffset = shape::getOffset(zTadShape, zIdx);
                resultTAD[resultOffset] = dataTAD[yOffset];
            }
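The same compaction applies to the coordinate helpers used above: index2coords now takes (index, shapeInfo, coords) and getOffset takes (shapeInfo, coords), with the old leading baseOffset argument (always passed as 0 above) removed. A sketch of the pair, reusing the illustrative ShapeInfoView descriptor from the earlier note (again an assumption, not the library source):

    // De-linearize a c-order index into per-dimension coordinates.
    static void index2coordsSketch(int64_t index, const ShapeInfoView& info, int64_t* coords) {
        for (int i = info.rank - 1; i >= 0; --i) {
            coords[i] = index % info.shape[i];
            index /= info.shape[i];
        }
    }

    // Map coordinates to a buffer offset via the strides.
    static int64_t getOffsetSketch(const ShapeInfoView& info, const int64_t* coords) {
        int64_t offset = 0;
        for (int i = 0; i < info.rank; ++i)
            offset += coords[i] * info.stride[i];
        return offset;
    }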
View File
@@ -53,7 +53,7 @@ namespace nd4j {
            if (dimensionLength > 1 || tadEWS < 1) {
                for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) {
-                   auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo, tadLength);
+                   auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo);
                    dZ[xOffset] = (e == highestElement ? (T) 1 : (T) 0);
                }
            } else {
View File
@@ -30,7 +30,7 @@ namespace nd4j {
        int tid = blockIdx.x * blockDim.x + threadIdx.x;
        for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x)
-           dz[shape::getIndexOffset(i, xShapeInfo, length)] = (i == idx ? (T) 1 : (T) 0);
+           dz[shape::getIndexOffset(i, xShapeInfo)] = (i == idx ? (T) 1 : (T) 0);
    }
////////////////////////////////////////////////////////////////////////
View File
@@ -20,6 +20,7 @@
//
#include <loops/special_kernels.h>
+#include <ops/declarable/helpers/flatten.h>
namespace nd4j {
@@ -34,34 +35,26 @@ __global__ void flattenKernel(
    auto z = reinterpret_cast<T*>(vz);
    auto y = reinterpret_cast<T*>(vy);
    __shared__ Nd4jLong lenY, yOrder, zEWS, yEWS;
    if (threadIdx.x == 0) {
        yEWS = shape::elementWiseStride(yShapeInfo);
        zEWS = shape::elementWiseStride(zShapeInfo);
        lenY = shape::length(yShapeInfo);
    }
    __syncthreads();
    Nd4jLong tid = blockIdx.x * blockDim.x + threadIdx.x;
-   if (zEWS >= 1 && yEWS >= 1 && yOrder == order) {
-       for (int i = tid; i < lenY; i += gridDim.x * blockDim.x)
-           z[i * zEWS + dOffset] = y[i * yEWS];
-   }
-   else {
-       for(auto i = tid; i < lenY; i += gridDim.x * blockDim.x)
-           z[i * zEWS + dOffset] = y[shape::getIndexOrderOffset(i, yShapeInfo, lenY, order)];
-   }
+   for(auto i = tid; i < lenY; i += gridDim.x * blockDim.x)
+       z[i * zEWS + dOffset] = y[ops::helpers::getIndexOffsetOrdered(i, yShapeInfo, order)];
}
////////////////////////////////////////////////////////////////////////
template <typename T>
__host__ void flattenKernelGeneric(dim3& launchDims, cudaStream_t *stream,
                        Nd4jPointer *extraPointers,
                        int dOffset,
                        char order,
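The rewritten flatten kernel replaces the EWS fast path with a single call to the new ops::helpers::getIndexOffsetOrdered, which this patch series introduces in flatten.h. A sketch of what an order-aware offset plausibly computes, based only on the call site above (an assumption, not the library's implementation): de-linearize i in the requested flattening order rather than in the array's own storage order.

    // Sketch: offset of the i-th element of y when y is traversed in `order`
    // ('c' row-major or 'f' column-major), whatever y's memory layout is.
    static int64_t getIndexOffsetOrderedSketch(int64_t index, const ShapeInfoView& info, char order) {
        int64_t offset = 0;
        if (order == 'c') {
            for (int i = info.rank - 1; i >= 0; --i) {
                offset += (index % info.shape[i]) * info.stride[i];
                index /= info.shape[i];
            }
        } else { // 'f'
            for (int i = 0; i < info.rank; ++i) {
                offset += (index % info.shape[i]) * info.stride[i];
                index /= info.shape[i];
            }
        }
        return offset;
    }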
View File
@@ -54,8 +54,8 @@ __global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 1;
                if (top < xTadLength) {
-                   auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
-                   auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
+                   auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo);
+                   auto t1 = shape::getIndexOffset(top, tadShapeInfo);
                    if (!descending == (dx[t0] > dx[t1])) {
                        X dt0 = dx[t0];
@@ -72,8 +72,8 @@ __global__ void execOesTadKernelKey(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 2;
                if (top < xTadLength) {
-                   auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
-                   auto t1 = shape::getIndexOffset(top, tadShapeInfo, xTadLength);
+                   auto t0 = shape::getIndexOffset(top - 1, tadShapeInfo);
+                   auto t1 = shape::getIndexOffset(top, tadShapeInfo);
                    if (!descending == (dx[t0] > dx[t1])) {
                        X dt0 = dx[t0];
@@ -126,7 +126,7 @@ __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
        int iterations = xTadLength;
        if (cached) {
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
-               auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
+               auto t0 = shape::getIndexOffset(tid, tadShapeInfo);
                shmem[tid] = dx[t0];
            }
@@ -140,8 +140,8 @@ __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 1;
                if (top < xTadLength) {
-                   auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
-                   auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
+                   auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo);
+                   auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo);
                    if (!descending == (dx[t0] > dx[t1])) {
                        T dt0 = dx[t0];
@@ -154,8 +154,8 @@ __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
                auto top = 2 * tid + 2;
                if (top < xTadLength) {
-                   auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo, xTadLength);
-                   auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo, xTadLength);
+                   auto t0 = cached ? top - 1 : shape::getIndexOffset(top - 1, tadShapeInfo);
+                   auto t1 = cached ? top : shape::getIndexOffset(top, tadShapeInfo);
                    if (!descending == (dx[t0] > dx[t1])) {
                        T dt0 = dx[t0];
@@ -172,7 +172,7 @@ __global__ void execOesTadKernel(void *vx, Nd4jLong *xShapeInfo,
        if (cached) {
            dx = x + tadOffsets[r];
            for (int tid = threadIdx.x; tid < xTadLength; tid += blockDim.x) {
-               auto t0 = shape::getIndexOffset(tid, tadShapeInfo, xTadLength);
+               auto t0 = shape::getIndexOffset(tid, tadShapeInfo);
                dx[t0] = shmem[tid];
            }
        }
View File
@@ -53,8 +53,8 @@ namespace nd4j {
            T *rZ = z + zTadOffsets[idx];
            for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
-               auto xOffset = shape::getIndexOffset(i, tadShapeInfo, tadLength);
-               auto zOffset = shape::getIndexOffset(i, zTadShapeInfo, tadLength);
+               auto xOffset = shape::getIndexOffset(i, tadShapeInfo);
+               auto zOffset = shape::getIndexOffset(i, zTadShapeInfo);
                rZ[zOffset] = rX[xOffset];
            }
        }
View File
@@ -33,7 +33,7 @@ namespace nd4j {
        for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) {
            for (int j = threadIdx.x; j < cols; j += totalThreads) {
                Nd4jLong coords[2] = {i, j};
-               Nd4jLong xOffset = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), coords, rank);
+               Nd4jLong xOffset = shape::getOffset(shape, coords);
                if (i + diagonal <= j)
                    array[xOffset] = value;
            }
@@ -48,7 +48,7 @@ namespace nd4j {
        for (Nd4jLong i = blockIdx.x; i < rows; i += gridDim.x) {
            for (int j = threadIdx.x; j < cols; j += totalThreads) {
                Nd4jLong coords[2] = {i, j};
-               auto xOffset = shape::getOffset(0, shape::shapeOf(shape), shape::stride(shape), coords, rank);
+               auto xOffset = shape::getOffset(shape, coords);
                if (i + diagonal >= j)
                    *(reinterpret_cast<T*>(buffer) + xOffset) = value;
            }
View File
@@ -92,7 +92,7 @@ namespace nd4j {
            } else {
                for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) {
-                   auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f], tadLength);
+                   auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo[f]);
                    auto yOffset = newOffset + xOffset;
                    xOffset += oldOffset;
View File
@@ -34,8 +34,8 @@ namespace nd4j {
            auto xEws = shape::order(theFirstShape) == 'c'? shape::elementWiseStride(theFirstShape) :1;
            auto yEws = shape::order(theSecondShape) == 'c'? shape::elementWiseStride(theSecondShape):1;
            //if (shape::order(theFirstShape) ==)
-           auto xOffset = shape::getIndexOffset(i * xEws, theFirstShape, resultLength);
-           auto yOffset = shape::getIndexOffset(i * yEws, theSecondShape, resultLength);
+           auto xOffset = shape::getIndexOffset(i * xEws, theFirstShape);
+           auto yOffset = shape::getIndexOffset(i * yEws, theSecondShape);
            T temp = *(reinterpret_cast<T*>(theFirstBuffer) + xOffset);
            *(reinterpret_cast<T*>(theFirstBuffer) + xOffset) = *(reinterpret_cast<T*>(theSecondBuffer) + yOffset);
            *(reinterpret_cast<T*>(theSecondBuffer) + yOffset) = temp;
View File
@@ -61,8 +61,8 @@ namespace nd4j {
            } else {
                for (Nd4jLong j = threadIdx.x; j < tadLength; j += blockDim.x) {
-                   auto xOffset = shape::getIndexOffset(j, tadShapeInfo, tadLength);
-                   auto zOffset = shape::getIndexOffset(j, zShapeInfo, tadLength);
+                   auto xOffset = shape::getIndexOffset(j, tadShapeInfo);
+                   auto zOffset = shape::getIndexOffset(j, zShapeInfo);
                    z[zOffset] = s[xOffset];
                }
View File
@@ -21,8 +21,8 @@
#include <loops/special_kernels.h>
namespace nd4j {
-   static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo, Nd4jLong length) {
-       return shape::getIndexOffset(index, shapeInfo, length);
+   static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+       return shape::getIndexOffset(index, shapeInfo);
    }
    static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
@@ -50,7 +50,7 @@ namespace nd4j {
            }
        } else {
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto xOffset = _getIndexOffset(i, outputShape, resultLength);
+               auto xOffset = _getIndexOffset(i, outputShape);
                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
            }
@@ -89,7 +89,7 @@ namespace nd4j {
            for (int i = tid; i < resultLength; i += totalThreads) {
-               auto xOffset = _getIndexOffset(i, outputShape, resultLength);
+               auto xOffset = _getIndexOffset(i, outputShape);
                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
            }
View File
@@ -40,7 +40,7 @@ namespace functions {
template <typename X, typename Z>
void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRank, void *extraParams, void *z, Nd4jLong *zShapeInfo, int zRank, int *dimension, int dimensionLength, int postProcessOrNot,bool biasCorrected,int *allocationBuffer, void *reductionBuffer, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
    functions::summarystats::SummaryStatsReduce<X,Z>::transform(op,dx,xShapeInfo,extraParams,z,zShapeInfo,dimension,dimensionLength,biasCorrected,allocationBuffer,reductionBuffer,tadOnlyShapeInfo,tadOffsets);
}
@@ -103,12 +103,12 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
*/
template<typename X, typename Z>
template<typename OpType>
_CUDA_D void SummaryStatsReduce<X,Z>::transform(void *vx, Nd4jLong *xShapeInfo,
                        void *vextraParams,
                        void *vz, Nd4jLong *zShapeInfo,
                        int *dimension, int dimensionLength,
                        int postProcessOrNot,
                        int *allocationBuffer, void *vreductionBuffer,
                        Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets) {
    auto dx = static_cast<X*>(vx);
@@ -204,7 +204,7 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
                sPartials[threadIdx.x] = val;
                for (int i = threadIdx.x; i < tadLength; i += blockDim.x) {
-                   auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
+                   auto xOffset = tadOffsetForBlock + shape::getIndexOffset(i, tadOnlyShapeInfo);
                    SummaryStatsData<X> indexVal2;
                    indexVal2.initWithValue(dx[xOffset]);
@@ -264,8 +264,8 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
            else {
                for (Nd4jLong i = tid; i < n; i += blockDim.x * gridDim.x) {
-                   auto offset = shape::getIndexOffset(i, xShapeInfo, n);
+                   auto offset = shape::getIndexOffset(i, xShapeInfo);
                    SummaryStatsData<X> indexVal2;
                    indexVal2.initWithValue(dx[offset]);
                    reduction = update(reduction, indexVal2, extraParams);
@@ -279,7 +279,7 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
            if (gridDim.x > 1) {
                __shared__ bool amLast;
                unsigned int *tc = (unsigned int *)reductionBuffer;
                tid = threadIdx.x;
                if (threadIdx.x == 0) {
                    SummaryStatsData<X> *pBuffer = (SummaryStatsData<X>*) reductionBuffer;
@@ -338,9 +338,9 @@ void _CUDA_G summaryStatsReduceT(int op, void *dx, Nd4jLong *xShapeInfo, int xRa
template <typename X, typename Z>
_CUDA_H void SummaryStatsReduce<X,Z>::execSummaryStatsReduceScalar(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, Nd4jLong *hxShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, Nd4jLong *hzShapeInfo, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, bool biasCorrected, void *reductionBuffer) {
    auto x = static_cast<X*>(vx);
    auto extraParams = static_cast<Z*>(vextraParams);
    auto z = reinterpret_cast<Z*>(vz);
    auto reductionPointerA = reinterpret_cast<Z*>(reductionBuffer);
View File
@@ -36,7 +36,7 @@ __global__ void transformAnySimple(void *x, Nd4jLong *xShapeInfo, int xRank,
                        int *allocationPointer,
                        void *reductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    functions::transform::TransformAny<X,Z>::template transformCuda<OpType>(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets);
}
@@ -57,7 +57,7 @@ namespace functions {
    __device__ void TransformAny<X,Z>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
                        void *vparams,
                        void *vz, Nd4jLong *zShapeInfo,
                        int *allocationPointer, void *vreductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
        auto x = reinterpret_cast<X*>(vx);
@@ -70,9 +70,9 @@ namespace functions {
        __shared__ char xOrder;
        __shared__ char zOrder;
        __shared__ Nd4jLong length;
        if (threadIdx.x == 0) {
            xEws = shape::elementWiseStride(xShapeInfo);
            zEws = shape::elementWiseStride(zShapeInfo);
            xOrder = shape::order(xShapeInfo);
@@ -84,26 +84,26 @@ namespace functions {
        auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        int totalThreads = gridDim.x * blockDim.x;
        if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
            for (int i = tid; i < length; i += totalThreads)
                z[i * zEws] = OpType::op(x[i * xEws], params);
        }
        else {
            if(vx == vz) {
                for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                   auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
+                   auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                    z[xOffset] = OpType::op(x[xOffset], params);
                }
            }
            else {
                for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                   auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
-                   auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
+                   auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+                   auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                    z[zOffset] = OpType::op(x[xOffset], params);
                }
            }
        }
    };
View File
@@ -68,16 +68,16 @@ namespace functions {
        if(OpType::requiresSpecial) {
            OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            return;
        }
        else {
            __shared__ Nd4jLong xEws;
            __shared__ Nd4jLong zEws;
            __shared__ char xOrder;
            __shared__ char zOrder;
            __shared__ Nd4jLong length;
            if (threadIdx.x == 0) {
                xEws = shape::elementWiseStride(xShapeInfo);
                zEws = shape::elementWiseStride(zShapeInfo);
                xOrder = shape::order(xShapeInfo);
@@ -87,28 +87,28 @@ namespace functions {
            __syncthreads();
            auto tid = blockIdx.x * blockDim.x + threadIdx.x;
            int totalThreads = gridDim.x * blockDim.x;
            if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
                for (int i = tid; i < length; i += totalThreads)
                    z[i * zEws] = OpType::op(x[i * xEws], params);
            }
            else {
                if(vx == vz) {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                        z[xOffset] = OpType::op(x[xOffset], params);
                    }
                }
                else {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
-                       auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+                       auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                        z[zOffset] = OpType::op(x[xOffset], params);
                    }
                }
            }
        }
    };
View File
@@ -35,7 +35,7 @@ __global__ void transformFloatSimple(void *x, Nd4jLong *xShapeInfo, int xRank,
                        int *allocationPointer,
                        void *reductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    functions::transform::TransformFloat<X,Z>::template transformCuda<OpType>(
        x, xShapeInfo,
        params,
@@ -64,7 +64,7 @@ namespace functions {
                        void *vparams,
                        void *vz,
                        Nd4jLong *zShapeInfo,
                        int *allocationPointer, void *vreductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
        auto x = reinterpret_cast<X*>(vx);
@@ -75,7 +75,7 @@ namespace functions {
        if(OpType::requiresSpecial) {
            OpType::execSpecialCuda(x,xShapeInfo,z,zShapeInfo,params, allocationPointer, reductionPointer, tadShapeInfo, tadOffsets);
            return;
        }
        else {
            __shared__ Nd4jLong xEws;
@@ -83,9 +83,9 @@ namespace functions {
            __shared__ char xOrder;
            __shared__ char zOrder;
            __shared__ Nd4jLong length;
            if (threadIdx.x == 0) {
                xEws = shape::elementWiseStride(xShapeInfo);
                zEws = shape::elementWiseStride(zShapeInfo);
                xOrder = shape::order(xShapeInfo);
@@ -95,24 +95,24 @@ namespace functions {
            __syncthreads();
            auto tid = blockIdx.x * blockDim.x + threadIdx.x;
            int totalThreads = gridDim.x * blockDim.x;
            if(xEws > 0 && zEws > 0 && xOrder == zOrder) {
                for (Nd4jLong i = tid; i < length; i += totalThreads)
                    z[i * zEws] = OpType::op(x[i * xEws], params);
            }
            else {
                if(vx == vz) {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                        z[xOffset] = OpType::op(x[xOffset], params);
                    }
                }
                else {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
-                       auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+                       auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                        z[zOffset] = OpType::op(x[xOffset], params);
                    }
                }
View File
@@ -95,14 +95,14 @@ namespace functions {
            else {
                if(vx == vz) {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                        z[xOffset] = OpType::op(x[xOffset], params);
                    }
                }
                else {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
-                       auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+                       auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                        z[zOffset] = OpType::op(x[xOffset], params);
                    }
                }
View File
@@ -35,7 +35,7 @@ __global__ void transformStrictSimple(void *x, Nd4jLong *xShapeInfo, int xRank,
                        int *allocationPointer,
                        void *reductionPointer,
                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
    functions::transform::TransformStrict<X>::template transformCuda<OpType>(x,xShapeInfo,params,z,zShapeInfo,allocationPointer,reductionPointer,tadShapeInfo, tadOffsets);
}
@@ -97,14 +97,14 @@ namespace functions {
            else {
                if(vx == vz) {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
                        z[xOffset] = OpType::op(x[xOffset], params);
                    }
                }
                else {
                    for (Nd4jLong i = tid; i < length; i+= totalThreads) {
-                       auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
-                       auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
+                       auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+                       auto zOffset = shape::getIndexOffset(i, zShapeInfo);
                        z[zOffset] = OpType::op(x[xOffset], params);
                    }
                }
View File
@@ -24,6 +24,7 @@
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
+#include <ops/declarable/helpers/addBias.h>
#include <MmulHelper.h>
namespace nd4j {
@@ -162,7 +163,8 @@ CUSTOM_OP_IMPL(conv3dnew, 2, 1, false, 0, 13) {
    MmulHelper::tensorDot(&columns, weights, output, {1,2,3,4}, {3,0,1,2}, permutForOutput);
    if(bias)
-       output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
+       // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
+       helpers::addBias(block, *output, *bias, *output, isNCDHW);
    if(!isNCDHW)
        delete input;
View File
@@ -27,7 +27,7 @@
#include <declarable/helpers/convolutions.h>
#include <ops/declarable/helpers/im2col.h>
#include <ops/declarable/helpers/col2im.h>
+#include <ops/declarable/helpers/addBias.h>
namespace nd4j {
namespace ops {
@@ -80,7 +80,8 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
    //----- add biases if required -----//
    if(bias)
-       output->applyBroadcast(broadcast::Add, {1}, bias);
+       // output->applyBroadcast(broadcast::Add, {1}, bias);
+       helpers::addBias(block, *output, *bias, *output, true);
    if(!isNCHW)
        delete output;
View File
@@ -23,6 +23,7 @@
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/convolutions.h>
+#include <ops/declarable/helpers/addBias.h>
#include <MmulHelper.h>
namespace nd4j {
@@ -79,7 +80,8 @@ CUSTOM_OP_IMPL(deconv3d, 2, 1, false, 0, 13) {
    //----- add biases if required -----//
    if(bias)
-       output->applyBroadcast(broadcast::Add,{1}, bias);
+       // output->applyBroadcast(broadcast::Add,{1}, bias);
+       helpers::addBias(block, *output, *bias, *output, true);
    if(!isNCDHW)
        delete output;
View File
@@ -15,107 +15,111 @@
******************************************************************************/
//
// @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_biasadd)
#include <ops/declarable/CustomOperations.h>
+#include<ops/declarable/helpers/addBias.h>
namespace nd4j {
namespace ops {
-   DECLARE_TYPES(biasadd) {
-       getOpDescriptor()
-           ->setAllowedInputTypes(nd4j::DataType::ANY)
-           ->setAllowedOutputTypes({ALL_FLOATS});
-   }
-   CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) {
-       //REQUIRE_OK(this->validateInput2D(block));
-       auto input = INPUT_VARIABLE(0);
-       auto bias = INPUT_VARIABLE(1);
-       REQUIRE_TRUE(bias->isRowVector(), 0, "Bias array should be a vector");
-       auto z = OUTPUT_VARIABLE(0);
-       if (input->isMatrix())
-           input->addRowVector(bias, z);
-       else {
-           // TODO: we might want to use NDArray::applyTrueBroadcast here, like AddOp does
-           std::vector<Nd4jLong> shape({-1, bias->lengthOf()});
-           //nd4j_debug("Reshaping to: [%i, %i]\n", -1, (int) bias->lengthOf());
-           auto tArr = input->reshape(input->ordering(), shape);
-           auto zArr = z->reshape(z->ordering(), shape);
-           tArr.addRowVector(bias, &zArr);
-       }
-       STORE_RESULT(*z);
-       return Status::OK();
-   }
-   DECLARE_SYN(bias_add, biasadd);
-   DECLARE_SHAPE_FN(biasadd) {
-       auto xShape = inputShape->at(0);
-       auto yShape = inputShape->at(1);
-       auto dtype = ArrayOptions::dataType(yShape);
-       return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(xShape, dtype)));
-   }
+   ////////////////////////////////////////////////////////////////////
+   CUSTOM_OP_IMPL(biasadd, 2, 1, true, 0, 0) {
+       auto input = INPUT_VARIABLE(0);
+       auto bias = INPUT_VARIABLE(1);
+       auto output = OUTPUT_VARIABLE(0);
+       const bool isNCHW = !block.getBArguments()->empty() ? B_ARG(0) : false;
+       const int channelDim = isNCHW ? 1 : input->rankOf() - 1;      // second or last
+       REQUIRE_TRUE(bias->rankOf() == 1, 0, "BIASADD CUSTOM_OP: bias array should have rank = 1, but got %i instead !", bias->rankOf());
+       REQUIRE_TRUE(bias->sizeAt(0) == input->sizeAt(channelDim), 0, "BIASADD CUSTOM_OP: shapes of bias %s and input %s arrays are not suitable for broadcast operation along channel dimension %i !", ShapeUtils::shapeAsString(bias).c_str(), ShapeUtils::shapeAsString(input).c_str(), channelDim);
+       REQUIRE_TRUE(output->isSameShape(input), 0, "BIASADD CUSTOM_OP: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(output).c_str());
+       helpers::addBias(block, *input, *bias, *output, isNCHW);
+       // input->applyBroadcast(nd4j::broadcast::Add, {channelDim}, bias, output);
+       return Status::OK();
+   }
+   DECLARE_SYN(bias_add, biasadd);
+   ////////////////////////////////////////////////////////////////////
+   DECLARE_SHAPE_FN(biasadd) {
+       auto xShape = inputShape->at(0);
+       auto yShape = inputShape->at(1);
+       auto dtype = ArrayOptions::dataType(yShape);
+       return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(xShape, dtype)));
+   }
+   DECLARE_TYPES(biasadd) {
+       getOpDescriptor()
+           ->setAllowedInputTypes(nd4j::DataType::ANY)
+           ->setAllowedOutputTypes({ALL_FLOATS});
+   }
-   DECLARE_TYPES(biasadd_bp) {
-       getOpDescriptor()
-           ->setAllowedInputTypes(nd4j::DataType::ANY)
-           ->setAllowedOutputTypes({ALL_FLOATS});
-   }
-   CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
+   ////////////////////////////////////////////////////////////////////
+   CUSTOM_OP_IMPL(biasadd_bp, 3, 2, false, 0, 0) {
        auto input = INPUT_VARIABLE(0);
        auto bias = INPUT_VARIABLE(1);
        auto epsilonNext = INPUT_VARIABLE(2);
        auto epsilon = OUTPUT_VARIABLE(0);
        auto gradB = OUTPUT_VARIABLE(1);
        epsilon->assign(epsilonNext);
        // cnn case
        if (input->rankOf() == 4) {
            auto epsilonNext2d = epsilonNext->permute({1, 0, 2, 3});
            epsilonNext2d.reshapei('c', {(int) bias->lengthOf(), -1});
            auto sum = epsilonNext2d.reduceAlongDimension(reduce::Sum, {1});
            gradB->assign(sum);
            delete sum;
        } else if (input->rankOf() == 2) {
            // regular fully-connected case
            auto sum = epsilonNext->reduceAlongDimension(reduce::Sum, {0});
            gradB->assign(sum);
            delete sum;
        }
        return ND4J_STATUS_OK;
    }
    DECLARE_SYN(BiasAddGrad, biasadd_bp);
    DECLARE_SHAPE_FN(biasadd_bp) {
        auto input = inputShape->at(0);
        auto bias = inputShape->at(1);
        Nd4jLong* epsShape;
        Nd4jLong* gradShape;
        COPY_SHAPE(input, epsShape);
        COPY_SHAPE(bias, gradShape);
        return SHAPELIST(CONSTANT(epsShape), CONSTANT(gradShape));
    }
+   DECLARE_TYPES(biasadd_bp) {
+       getOpDescriptor()
+           ->setAllowedInputTypes(nd4j::DataType::ANY)
+           ->setAllowedOutputTypes({ALL_FLOATS});
+   }
}
}
#endif
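The rewritten op derives the channel dimension from the data format instead of requiring a row-vector bias: dimension 1 for NCHW, the last dimension otherwise. A standalone illustration of that validation rule (a hypothetical helper for this note, not part of the patch):

    #include <cstdio>
    #include <vector>

    // bias must be rank-1 and match the channel extent of the input.
    static bool biasShapeOk(const std::vector<long long>& inShape,
                            const std::vector<long long>& biasShape,
                            bool isNCHW) {
        const size_t channelDim = isNCHW ? 1 : inShape.size() - 1; // second or last
        return biasShape.size() == 1 && biasShape[0] == inShape[channelDim];
    }

    int main() {
        // NHWC input [2,4,4,3] accepts a bias of length 3 ...
        std::printf("%d\n", biasShapeOk({2, 4, 4, 3}, {3}, false)); // prints 1
        // ... but the same arrays fail when interpreted as NCHW (channel extent 4)
        std::printf("%d\n", biasShapeOk({2, 4, 4, 3}, {3}, true));  // prints 0
        return 0;
    }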
View File
@@ -43,14 +43,15 @@ DECLARE_SHAPE_FN(matrix_diag) {
    auto in = inputShape->at(0);
    int inRank = shape::rank(in);
+   // if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C]
    int outRank = inRank + 1;
-   auto lastDimension = shape::sizeAt(in, -1);
    ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
    outShapeInfo[0] = outRank;
    for(int i = 0; i < inRank; ++i)
        outShapeInfo[i + 1] = shape::sizeAt(in, i);
-   outShapeInfo[outRank] = lastDimension;
+   outShapeInfo[outRank] = shape::sizeAt(in, -1);
    ShapeUtils::updateStridesAndType(outShapeInfo, in, shape::order(in));
View File
@@ -23,7 +23,7 @@
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/reverse.h>
+#include <ops/declarable/helpers/addBias.h>
namespace nd4j {
namespace ops {
@@ -59,7 +59,8 @@ namespace ops {
    output->applyBroadcast(nd4j::broadcast::Multiply, {dimC}, gain);
    if(bias != nullptr) {
        // output->applyTrueBroadcast(nd4j::BroadcastOpsTuple::Add(), bias, output);
-       output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias);
+       // output->applyBroadcast(nd4j::broadcast::Add, {dimC}, bias);
+       helpers::addBias(block, *output, *bias, *output, isNCHW);
    }
    return Status::OK();
View File
@ -79,36 +79,44 @@ namespace nd4j {
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
* *
* Input arrays: * Input arrays:
* input: input array, considered as batch of matrices * 0: input array, considered as batch of matrices
* diagonal: array containing elements to be inserted into input array, * 1: diagonal array containing elements to be inserted into input array,
* following rank condition should be satisfied: diagonal_rank = input_rank - 1, * following rank condition should be satisfied: diagonal_rank = input_rank - 1,
* the shapes of diagonal and input arrays must be equal except last dimension of input array, * the shapes of diagonal and input arrays must be equal except last dimension of input array,
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C], * for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions * also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2]) * that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
* *
* Output array: * Output array:
* has the same shape as input, corresponding diagonal elements are substituted * 0: has the same shape as input, corresponding diagonal elements are substituted
*/ */
#if NOT_EXCLUDED(OP_matrix_set_diag) #if NOT_EXCLUDED(OP_matrix_set_diag)
DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0); DECLARE_CONFIGURABLE_OP(matrix_set_diag, 2, 1, false, 0, 0);
#endif #endif
/** /**
* Returns a batched matrix tensor with diagonal values given (as TF.matrix_diag). * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of output array,
*/ * rest output elements are set to zeros
*
* Input array:
* diagonal: array containing elements to be inserted into output array,
* following rank condition is present: diagonal_rank = ouput_rank - 1
*
* Output array:
* 0: is considered as batch of matrices, if for example diagonal array has shape [A,B,C] then output array has shape [A,B,C,C]
*/
DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0); DECLARE_CUSTOM_OP(matrix_diag, 1, 1, false, 0, 0);
/** /**
* This op calculates regularized incomplete beta integral Ix(a, b). * This op calculates regularized incomplete beta integral Ix(a, b).
* Implementation is based on two algorithms depending on input values of a and b: * Implementation is based on two algorithms depending on input values of a and b:
* - when a and b are both > maxValue (3000.), then apply Gauss-Legendre quadrature method * - when a and b are both > maxValue (3000.), then Gauss-Legendre quadrature method is applied
* - when a and b are both <= maxValue (3000.), then apply modified Lentz's algorithm for continued fractions * - when a and b are both <= maxValue (3000.), then modified Lentz's algorithm for continued fractions is applied
* *
* Input arrays: * Input arrays:
* a: define power t^{a-1}, must be > 0, type float. * a: defines power t^{a-1}, must be > 0, type float.
* b: define power (1-t)^{b-1}, must be > 0, type float. * b: defines power (1-t)^{b-1}, must be > 0, type float.
* x: define upper limit of integration, must be within (0 <= x <= 1) range, type float. * x: defines upper limit of integration, must be within (0 <= x <= 1) range, type float.
* *
* Output array: * Output array:
* 0: values of regularized incomplete beta integral that correspond to variable upper limit x, type float * 0: values of regularized incomplete beta integral that correspond to variable upper limit x, type float
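For reference, the quantity documented above is, in standard notation,

I_x(a,b) = \frac{1}{B(a,b)} \int_0^x t^{a-1} (1-t)^{b-1} \, dt,
\qquad
B(a,b) = \int_0^1 t^{a-1} (1-t)^{b-1} \, dt,

so the output is monotone in x, with I_0(a,b) = 0 and I_1(a,b) = 1.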


@ -22,14 +22,15 @@
#define LIBND4J_ADDBIAS_H #define LIBND4J_ADDBIAS_H
#include <ops/declarable/helpers/helpers.h> #include <ops/declarable/helpers/helpers.h>
#include <graph/Context.h>
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
namespace helpers { namespace helpers {
void addBias(NDArray& input, const NDArray& bias, const bool isNCHW); void addBias(graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW);
} }
} }
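A hedged call-site sketch of the new signature: block is the op's graph::Context, and in-place application is expressed by passing the same array as input and output. The macro names below follow the usual libnd4j op boilerplate and are shown for illustration only:

// inside an op's implementation (sketch, not verbatim from the codebase):
// auto input  = INPUT_VARIABLE(0);   // e.g. [bS, oC, oH, oW] for NCHW
// auto bias   = INPUT_VARIABLE(1);   // [oC]
// auto output = OUTPUT_VARIABLE(0);  // same shape as input
// helpers::addBias(block, *input, *bias, *output, /*isNCHW=*/true);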


@ -91,19 +91,19 @@ static void softMaxForVector_(void *input, Nd4jLong *inShapeInfo, void *output,
PRAGMA_OMP_SIMD_ARGS(reduction(OMP_MAXT:max)) PRAGMA_OMP_SIMD_ARGS(reduction(OMP_MAXT:max))
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo, length); const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo);
max = nd4j::math::nd4j_max<T>(max, inBuff[offset]); max = nd4j::math::nd4j_max<T>(max, inBuff[offset]);
} }
PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(reduction(OMP_SUMT:sum)) PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(reduction(OMP_SUMT:sum))
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo, length); const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo);
outBuff[offset] = nd4j::math::nd4j_exp<T, T>(inBuff[offset] - max); outBuff[offset] = nd4j::math::nd4j_exp<T, T>(inBuff[offset] - max);
sum += outBuff[offset]; sum += outBuff[offset];
} }
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int i = 0; i < length; i++) { for (int i = 0; i < length; i++) {
const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo, length); const Nd4jLong offset = shape::getIndexOffset(i, inShapeInfo);
outBuff[offset] /= sum; outBuff[offset] /= sum;
outBuff[offset] *= (1.f - outBuff[offset]); // derivative outBuff[offset] *= (1.f - outBuff[offset]); // derivative
} }
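The dropped length argument reflects that the offset depends only on the shape info: getIndexOffset unravels the linear index into coordinates and dots them with the strides. A simplified sketch for the common c-order case, assuming the usual shapeInfo layout of rank, then shape, then strides (the real function also covers f-order layouts and element-wise-stride fast paths):

// simplified sketch of what shape::getIndexOffset computes for c-order arrays
Nd4jLong getIndexOffsetSketch(Nd4jLong i, const Nd4jLong* shapeInfo) {
    const int rank = static_cast<int>(shapeInfo[0]);
    const Nd4jLong* shape   = shapeInfo + 1;
    const Nd4jLong* strides = shapeInfo + 1 + rank;
    Nd4jLong offset = 0;
    for (int j = rank - 1; j >= 0; --j) { // c-order: last dimension varies fastest
        offset += (i % shape[j]) * strides[j];
        i /= shape[j];
    }
    return offset;
}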


@ -28,70 +28,116 @@ namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void addBias_(NDArray& input, const NDArray& bias, const bool isNCHW) { static void addBias_(const NDArray& input, const NDArray& bias, NDArray &output, const bool isNCHW) {
// input [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) // bias [oC]
// bias [oC]
X* inBuff = input.bufferAsT<X>(); // if(input_rank == 4)
const Y* biasBuff = bias.bufferAsT<Y>(); // input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
// if(input_rank == 5)
// input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCHW)
// else
// fall back to applyBroadcast
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
bS = input.sizeAt(0);
const Nd4jLong stride0 = input.stridesOf()[0];
const Nd4jLong stride1 = input.stridesOf()[1];
const Nd4jLong stride2 = input.stridesOf()[2];
uint biasShapeInfoCast[MAX_RANK]; const X* x = input.bufferAsT<X>();
bool canCastBias = nd4j::DataTypeUtils::castShapeInfo(bias.getShapeInfo(), biasShapeInfoCast); const Y* y = bias.bufferAsT<Y>();
X* z = output.bufferAsT<X>();
if(isNCHW) {
oC = input.sizeAt(1);
oH = input.sizeAt(2);
oW = input.sizeAt(3);
const int oHoW = oH*oW; const bool inOutAreSame = x == z;
PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) const uint bS = output.sizeAt(0); // batch size
for (int i = 0; i < bS; ++i) { const Nd4jLong yStrideC = bias.stridesOf()[0];
for (int c = 0; c < oC; ++c) { const Nd4jLong zStrideB = output.stridesOf()[0];
auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias);
auto inOffset = i * stride0 + c * stride1;
PRAGMA_OMP_SIMD if(output.rankOf() == 4) {
for (uint k = 0; k < oHoW; ++k)
inBuff[inOffset + k] += static_cast<X>(biasBuff[biasOffset]); const uint C = isNCHW ? output.sizeAt(1) : output.sizeAt(3); // channels
} const uint oH = isNCHW ? output.sizeAt(2) : output.sizeAt(1); // height
const uint oW = isNCHW ? output.sizeAt(3) : output.sizeAt(2); // width
const Nd4jLong zStrideC = isNCHW ? output.stridesOf()[1] : output.stridesOf()[3];
const Nd4jLong zStrideH = isNCHW ? output.stridesOf()[2] : output.stridesOf()[1];
const Nd4jLong zStrideW = isNCHW ? output.stridesOf()[3] : output.stridesOf()[2];
if(inOutAreSame) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(4))
for(uint b = 0; b < bS; ++b)
for(uint c = 0; c < C; ++c)
for(uint h = 0; h < oH ; ++h)
for(uint w = 0; w < oW ; ++w)
z[b*zStrideB + c*zStrideC + h*zStrideH + w*zStrideW] += static_cast<X>(y[c*yStrideC]);
}
else {
const Nd4jLong xStrideB = input.stridesOf()[0];
const Nd4jLong xStrideC = isNCHW ? input.stridesOf()[1] : input.stridesOf()[3];
const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1];
const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2];
PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(4))
for(uint b = 0; b < bS; ++b)
for(uint c = 0; c < C; ++c)
for(uint h = 0; h < oH ; ++h)
for(uint w = 0; w < oW ; ++w)
z[b*zStrideB + c*zStrideC + h*zStrideH + w*zStrideW] = x[b*xStrideB + c*xStrideC + h*xStrideH + w*xStrideW] + static_cast<X>(y[c*yStrideC]);
}
}
else if(output.rankOf() == 5) {
const uint C = isNCHW ? output.sizeAt(1) : output.sizeAt(4); // channels
const uint oD = isNCHW ? output.sizeAt(2) : output.sizeAt(1); // depth
const uint oH = isNCHW ? output.sizeAt(3) : output.sizeAt(2); // height
const uint oW = isNCHW ? output.sizeAt(4) : output.sizeAt(3); // width
const Nd4jLong zStrideC = isNCHW ? output.stridesOf()[1] : output.stridesOf()[4];
const Nd4jLong zStrideD = isNCHW ? output.stridesOf()[2] : output.stridesOf()[1];
const Nd4jLong zStrideH = isNCHW ? output.stridesOf()[3] : output.stridesOf()[2];
const Nd4jLong zStrideW = isNCHW ? output.stridesOf()[4] : output.stridesOf()[3];
if(inOutAreSame) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(5))
for(uint b = 0; b < bS; ++b)
for(uint c = 0; c < C; ++c)
for(uint d = 0; d < oD ; ++d)
for(uint h = 0; h < oH ; ++h)
for(uint w = 0; w < oW ; ++w)
z[b*zStrideB + c*zStrideC + d*zStrideD + h*zStrideH + w*zStrideW] += static_cast<X>(y[c*yStrideC]);
}
else {
const Nd4jLong xStrideB = input.stridesOf()[0];
const Nd4jLong xStrideC = isNCHW ? input.stridesOf()[1] : input.stridesOf()[4];
const Nd4jLong xStrideD = isNCHW ? input.stridesOf()[2] : input.stridesOf()[1];
const Nd4jLong xStrideH = isNCHW ? input.stridesOf()[3] : input.stridesOf()[2];
const Nd4jLong xStrideW = isNCHW ? input.stridesOf()[4] : input.stridesOf()[3];
PRAGMA_OMP_PARALLEL_FOR_ARGS(collapse(5))
for(uint b = 0; b < bS; ++b)
for(uint c = 0; c < C; ++c)
for(uint d = 0; d < oD ; ++d)
for(uint h = 0; h < oH ; ++h)
for(uint w = 0; w < oW ; ++w)
z[b*zStrideB + c*zStrideC + d*zStrideD + h*zStrideH + w*zStrideW] = x[b*xStrideB + c*xStrideC + d*xStrideD + h*xStrideH + w*xStrideW] + static_cast<X>(y[c*yStrideC]);
} }
} }
else { else {
const int channelDim = isNCHW ? 1 : input.rankOf() - 1; // second or last
oC = input.sizeAt(3); const_cast<NDArray&>(input).applyBroadcast(nd4j::broadcast::Add, {channelDim}, &bias, &output);
oH = input.sizeAt(1); }
oW = input.sizeAt(2);
PRAGMA_OMP_PARALLEL_FOR
for (int i = 0; i < bS*oH*oW; ++i) {
PRAGMA_OMP_SIMD
for (int c = 0; c < oC; ++c) {
auto biasOffset = shape::indexOffset(c, bias.getShapeInfo(), biasShapeInfoCast, oC, canCastBias);
inBuff[i * oC + c] += static_cast<X>(biasBuff[biasOffset]);
}
}
}
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void addBias(NDArray& input, const NDArray& bias, const bool isNCHW) { void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) {
BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, (input, bias, isNCHW), FLOAT_TYPES, FLOAT_TYPES); // bias.rankOf() == 1 ? bias : bias.reshape(bias.ordering(), {bias.lengthOf()})
BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBias_, (input, bias, output, isNCHW), FLOAT_TYPES, FLOAT_TYPES);
} }
BUILD_DOUBLE_TEMPLATE(template void addBias_, (NDArray& input, const NDArray& bias, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES); BUILD_DOUBLE_TEMPLATE(template void addBias_, (const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW), FLOAT_TYPES, FLOAT_TYPES);
} }
} }
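Design note: the rank-4/rank-5 fast paths above replace the generic offset machinery with raw stride arithmetic, and the inOutAreSame branch avoids re-reading x when input and output alias. For a contiguous NCHW tensor the stride form collapses to the familiar flat index:

// contiguous NCHW: zStrideB = C*H*W, zStrideC = H*W, zStrideH = W, zStrideW = 1
// so  b*zStrideB + c*zStrideC + h*zStrideH + w*zStrideW  ==  ((b*C + c)*H + h)*W + w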


@ -84,7 +84,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
const Nd4jLong end = start + step; const Nd4jLong end = start + step;
// calculate offset for mean, variance, gamma, beta (all of them have the same shape) // calculate offset for mean, variance, gamma, beta (all of them have the same shape)
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean); auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, canCastMean);
// calculate offset for input and output (all of them have the same shape) // calculate offset for input and output (all of them have the same shape)
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data()); shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());
@ -114,7 +114,7 @@ static void batchnorm_(const NDArray* input, const NDArray* mean, const NDArray*
const Nd4jLong end = start + step; const Nd4jLong end = start + step;
// calculate offset for mean, variance, gamma, beta (all of them have the same shape) // calculate offset for mean, variance, gamma, beta (all of them have the same shape)
auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, lenSmall, canCastMean); auto offsetSmall = shape::indexOffset(j, meanShapeInfo, meanShapeInfoCast, canCastMean);
// calculate offset for input and output (all of them have the same shape) // calculate offset for input and output (all of them have the same shape)
shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data()); shape::outerArrayOffsets(inOffsets, j, inShapeInfo, meanShapeInfo, memBuff, dimsToExclude.data());


@ -29,7 +29,7 @@ namespace helpers {
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
// modified Lentz's algorithm for continued fractions, // modified Lentz's algorithm for continued fractions,
// reference: Lentz, W.J. 1976, "Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions", // reference: Lentz, W.J. 1976, "Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions"
template <typename T> template <typename T>
static T continuedFraction(const T a, const T b, const T x) { static T continuedFraction(const T a, const T b, const T x) {
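For context, a generic sketch of the modified Lentz iteration referenced above, which evaluates f = b0 + a1/(b1 + a2/(b2 + ...)). The tiny constant is the standard guard against zero denominators; the aTerm/bTerm callables standing in for the incomplete-beta coefficients are assumptions of this sketch, not the helper's actual interface:

#include <cmath>
#include <functional>

double lentzSketch(const std::function<double(int)>& aTerm,
                   const std::function<double(int)>& bTerm,
                   int maxIter = 200, double eps = 1e-12) {
    const double tiny = 1e-30;                       // guard value per the standard formulation
    double f = bTerm(0); if (f == 0.0) f = tiny;
    double C = f, D = 0.0;
    for (int n = 1; n <= maxIter; ++n) {
        D = bTerm(n) + aTerm(n) * D; if (D == 0.0) D = tiny;
        C = bTerm(n) + aTerm(n) / C; if (C == 0.0) C = tiny;
        D = 1.0 / D;
        const double delta = C * D;
        f *= delta;
        if (std::fabs(delta - 1.0) < eps) break;     // converged
    }
    return f;
}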
@ -122,9 +122,8 @@ static void betaIncForArray(nd4j::LaunchContext * context, const NDArray& a, con
int xLen = x.lengthOf(); int xLen = x.lengthOf();
PRAGMA_OMP_PARALLEL_FOR_IF(xLen > Environment::getInstance()->elementwiseThreshold()) PRAGMA_OMP_PARALLEL_FOR_IF(xLen > Environment::getInstance()->elementwiseThreshold())
for(int i = 0; i < xLen; ++i) { for(int i = 0; i < xLen; ++i)
output.p(i, betaIncCore<T>(a.e<T>(i), b.e<T>(i), x.e<T>(i))); output.t<T>(i) = betaIncCore<T>(a.t<T>(i), b.t<T>(i), x.t<T>(i));
}
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////


@ -648,7 +648,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
//----- add biases if required -----// //----- add biases if required -----//
if(bias) if(bias)
// output->applyBroadcast(broadcast::Add, {indIOioC}, bias); // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
helpers::addBias(*output, *bias, isNCHW); helpers::addBias(block, *output, *bias, *output, isNCHW);
if(!isNCHW) if(!isNCHW)
delete input; delete input;
@ -875,7 +875,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template <typename X, typename Y> template <typename X, typename Y>
static void depthwiseConv2d_(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { static void depthwiseConv2d_(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) // input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
// weights [kH, kW, iC, mC] always // weights [kH, kW, iC, mC] always
@ -922,7 +922,8 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] MmulHelper::tensorDot(&columns, weights, &outputReshaped, modifColumns, {{2,0,1,3},{iC,kH*kW,mC}}, modifOutput); // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC]
if(bias) if(bias)
output->applyBroadcast(broadcast::Add, {indIOioC}, bias); // output->applyBroadcast(broadcast::Add, {indIOioC}, bias);
helpers::addBias(block, *output, *bias, *output, isNCHW);
if(!isNCHW) if(!isNCHW)
delete input; delete input;
@ -2451,7 +2452,7 @@ void ConvolutionUtils::getMKLDNNMemoryDescConv3d(
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), conv2dBP_, (block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
} }
void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2d(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2d_, (block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);
} }
void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) { void ConvolutionUtils::depthwiseConv2dBP(nd4j::graph::Context& block, const NDArray* input, const NDArray* weights, const NDArray* bias, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode, const int isNCHW) {
BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES); BUILD_SINGLE_SELECTOR_TWICE(input->dataType(), depthwiseConv2dBP_, (input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW), FLOAT_TYPES);


@ -37,24 +37,16 @@ namespace nd4j {
cOffset += inputs[e]->lengthOf(); cOffset += inputs[e]->lengthOf();
} }
Nd4jLong xCoord[MAX_RANK];
// actually transferring data // actually transferring data
for (int e = 0; e < numArrays; e++) { for (int e = 0; e < numArrays; e++) {
auto z = reinterpret_cast<T *>(output->bufferWithOffset(offsets[e])); auto z = reinterpret_cast<T *>(output->bufferWithOffset(offsets[e]));
auto xBuffer = inputs[e]->bufferAsT<T>(); auto xBuffer = inputs[e]->bufferAsT<T>();
auto xShapeInfo = inputs[e]->shapeInfo(); auto xShapeInfo = inputs[e]->shapeInfo();
auto xShape = shape::shapeOf(xShapeInfo);
auto xStride = shape::stride(xShapeInfo);
auto xRank = shape::rank(xShapeInfo);
auto xLength = inputs[e]->lengthOf(); auto xLength = inputs[e]->lengthOf();
for (uint i = 0; i < xLength; i++) { for (uint i = 0; i < xLength; i++)
shape::index2coords(xRank, xShape, i, xLength, xCoord, order); z[i] = xBuffer[getIndexOffsetOrdered(i, xShapeInfo, order)];
auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank);
z[i] = xBuffer[xOffset];
}
} }
} }
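The new getIndexOffsetOrdered fuses the old index2coords + getOffset pair: it unravels i in the requested enumeration order ('c' or 'f') and then applies the array's own strides. A sketch under that reading (shapeInfo layout again assumed as rank, shape, strides):

Nd4jLong getIndexOffsetOrderedSketch(Nd4jLong i, const Nd4jLong* shapeInfo, const char order) {
    const int rank = static_cast<int>(shapeInfo[0]);
    const Nd4jLong* shape   = shapeInfo + 1;
    const Nd4jLong* strides = shapeInfo + 1 + rank;
    Nd4jLong offset = 0;
    if (order == 'c')
        for (int j = rank - 1; j >= 0; --j) { offset += (i % shape[j]) * strides[j]; i /= shape[j]; }
    else // 'f': first dimension varies fastest
        for (int j = 0; j < rank; ++j)      { offset += (i % shape[j]) * strides[j]; i /= shape[j]; }
    return offset;
}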


@ -30,7 +30,7 @@ namespace helpers {
template <typename X, typename Z> template <typename X, typename Z>
static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) { static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>& dimensions) {
if (input->isVector()) { if (input->isVector()) {
int dimensionsLength = dimensions.size(); int dimensionsLength = dimensions.size();
int length = input->lengthOf(); int length = input->lengthOf();
@ -169,7 +169,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0; rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0;
} }
} }
else if (tadEWS > 1 && zEWS > 1) { else if (tadEWS > 1 && zEWS > 1) {
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
if (rX[i * tadEWS] > maxValue) { if (rX[i * tadEWS] > maxValue) {
@ -184,7 +184,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
} }
} else { } else {
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo, tadLength); auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo);
if (rX[xOffset] > maxValue) { if (rX[xOffset] > maxValue) {
maxIdx = i; maxIdx = i;
maxValue = rX[xOffset]; maxValue = rX[xOffset];
@ -193,7 +193,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
auto zOffset = shape::getIndexOffset(i, tadPackZ.primaryShapeInfo(), tadLength); auto zOffset = shape::getIndexOffset(i, tadPackZ.primaryShapeInfo());
rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0; rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0;
} }
} }


@ -52,14 +52,14 @@ void matrixSetDiag_(const NDArray& input, const NDArray& diagonal, NDArray& outp
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords)) PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords))
for (Nd4jLong i = 0; i < xLen; ++i) { for (Nd4jLong i = 0; i < xLen; ++i) {
shape::index2coords(xRank, xShapeInfo + 1, i, xLen, coords.data()); shape::index2coords(i, xShapeInfo, coords.data());
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xRank + 1, coords.data(), xRank); const auto xOffset = shape::getOffset(xShapeInfo, coords.data());
const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(0, zShapeInfo + 1, zShapeInfo + xRank + 1, coords.data(), xRank); const auto zOffset = areSameOffsets ? xOffset : shape::getOffset(zShapeInfo, coords.data());
// condition to be on diagonal of innermost matrix // condition to be on diagonal of innermost matrix
if(coords[xRank - 2] == coords[xRank - 1]) if(coords[xRank - 2] == coords[xRank - 1])
z[zOffset] = y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + xRank, coords.data(), xRank - 1)]; z[zOffset] = y[shape::getOffset(yShapeInfo, coords.data())];
else else
z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset]; z[zOffset] = zeroPad ? static_cast<T>(0) : x[xOffset];
} }


@ -73,12 +73,12 @@ namespace nd4j {
if (idx < 0 || idx >= tLen) { if (idx < 0 || idx >= tLen) {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int t = 0; t < tLen; t++) { for (unsigned int t = 0; t < tLen; t++) {
cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo(), tLen)] = zero; cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = zero;
} }
} else { } else {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int t = 0; t < tLen; t++) { for (unsigned int t = 0; t < tLen; t++) {
cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo(), tLen)] = idx == t ? one : zero; cO[shape::getIndexOffset(t, tadPack.primaryShapeInfo())] = idx == t ? one : zero;
} }
} }
} }


@ -53,8 +53,8 @@ namespace nd4j {
for (Nd4jLong e = length - 1; e >= 0; --e) { for (Nd4jLong e = length - 1; e >= 0; --e) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, length); auto xOffset = shape::getIndexOffset(e, xShapeInfo);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, length); auto zOffset = shape::getIndexOffset(e, zShapeInfo);
sum = op == scalar::Add ? simdOps::Add<T, T, T>::op(sum, x[xOffset]) : simdOps::Multiply<T, T, T>::op(sum, x[xOffset]); sum = op == scalar::Add ? simdOps::Add<T, T, T>::op(sum, x[xOffset]) : simdOps::Multiply<T, T, T>::op(sum, x[xOffset]);
if (!exclusive) if (!exclusive)
@ -83,8 +83,8 @@ namespace nd4j {
for (int e = 0; e < length; e++) { for (int e = 0; e < length; e++) {
auto xOffset = shape::getIndexOffset(e, xShapeInfo, length); auto xOffset = shape::getIndexOffset(e, xShapeInfo);
auto zOffset = shape::getIndexOffset(e, zShapeInfo, length); auto zOffset = shape::getIndexOffset(e, zShapeInfo);
sum = op == scalar::Add ? simdOps::Add<T, T, T>::op(sum, x[xOffset]) : simdOps::Multiply<T, T, T>::op(sum, x[xOffset]); sum = op == scalar::Add ? simdOps::Add<T, T, T>::op(sum, x[xOffset]) : simdOps::Multiply<T, T, T>::op(sum, x[xOffset]);
if (!exclusive) if (!exclusive)


@ -60,7 +60,7 @@ static void reverseArray(nd4j::LaunchContext * context, void *vinArr, Nd4jLong *
// inArr[e] = inArr[idx]; // inArr[e] = inArr[idx];
// inArr[idx] = tmp; // inArr[idx] = tmp;
} }
} }
else if (inEWS > 1) { else if (inEWS > 1) {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < numOfElemsToReverse / 2; e++) { for (Nd4jLong e = 0; e < numOfElemsToReverse / 2; e++) {
@ -71,19 +71,19 @@ static void reverseArray(nd4j::LaunchContext * context, void *vinArr, Nd4jLong *
// inArr[idx1] = tmp; // inArr[idx1] = tmp;
swap(inArr, idx1, idx2); swap(inArr, idx1, idx2);
} }
} }
else { else {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < numOfElemsToReverse / 2; e++) { for (Nd4jLong e = 0; e < numOfElemsToReverse / 2; e++) {
auto inOffset = shape::getIndexOffset(e, inShapeBuffer, inLength); auto inOffset = shape::getIndexOffset(e, inShapeBuffer);
auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer, inLength); auto outOffset = shape::getIndexOffset(sLength - e, inShapeBuffer);
//outArr[outOffset] = inArr[inOffset]; //outArr[outOffset] = inArr[inOffset];
swap(outArr, inOffset, outOffset); swap(outArr, inOffset, outOffset);
} }
} }
} }
else { else {
// single step phase here // single step phase here
auto outEWS = shape::elementWiseStride(outShapeBuffer); auto outEWS = shape::elementWiseStride(outShapeBuffer);
@ -92,15 +92,15 @@ static void reverseArray(nd4j::LaunchContext * context, void *vinArr, Nd4jLong *
if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) { if (inEWS == 1 && outEWS == 1 && inOrder == outOrder) {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < numOfElemsToReverse; e++) for (Nd4jLong e = 0; e < numOfElemsToReverse; e++)
outArr[sLength - e] = inArr[e]; outArr[sLength - e] = inArr[e];
if(inLength != numOfElemsToReverse) { if(inLength != numOfElemsToReverse) {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++) for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++)
outArr[e] = inArr[e]; outArr[e] = inArr[e];
} }
} }
else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) { else if (inEWS >= 1 && outEWS >= 1 && inOrder == outOrder) {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
@ -112,14 +112,14 @@ static void reverseArray(nd4j::LaunchContext * context, void *vinArr, Nd4jLong *
for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++) for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++)
outArr[e * outEWS] = inArr[e * inEWS]; outArr[e * outEWS] = inArr[e * inEWS];
} }
} }
else { else {
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = 0; e < numOfElemsToReverse; e++) { for (Nd4jLong e = 0; e < numOfElemsToReverse; e++) {
auto inOffset = shape::getIndexOffset(e, inShapeBuffer, inLength); auto inOffset = shape::getIndexOffset(e, inShapeBuffer);
auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer, outLength); auto outOffset = shape::getIndexOffset(sLength - e, outShapeBuffer);
outArr[outOffset] = inArr[inOffset]; outArr[outOffset] = inArr[inOffset];
} }
@ -128,9 +128,9 @@ static void reverseArray(nd4j::LaunchContext * context, void *vinArr, Nd4jLong *
PRAGMA_OMP_PARALLEL_FOR PRAGMA_OMP_PARALLEL_FOR
for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++) { for (Nd4jLong e = numOfElemsToReverse; e < inLength; e++) {
auto inOffset = shape::getIndexOffset(e, inShapeBuffer, inLength); auto inOffset = shape::getIndexOffset(e, inShapeBuffer);
auto outOffset = shape::getIndexOffset(e, outShapeBuffer, outLength); auto outOffset = shape::getIndexOffset(e, outShapeBuffer);
outArr[outOffset] = inArr[inOffset]; outArr[outOffset] = inArr[inOffset];
} }
} }
} }
@ -151,7 +151,7 @@ static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input
helpers::reverseArray<T>(context, const_cast<NDArray*>(input)->getBuffer(), const_cast<NDArray*>(input)->getShapeInfo(), output->getBuffer(), output->getShapeInfo(), seqLengths->e<int>(0)); helpers::reverseArray<T>(context, const_cast<NDArray*>(input)->getBuffer(), const_cast<NDArray*>(input)->getShapeInfo(), output->getBuffer(), output->getShapeInfo(), seqLengths->e<int>(0));
} }
else { else {
if(seqDim > batchDim) if(seqDim > batchDim)
--seqDim; --seqDim;
@ -163,7 +163,7 @@ static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input
for(int i = 0; i < inSubArrsSet->size(); ++i) { for(int i = 0; i < inSubArrsSet->size(); ++i) {
Nd4jLong numOfElemsToReverse = seqLengths->e<Nd4jLong>(i); Nd4jLong numOfElemsToReverse = seqLengths->e<Nd4jLong>(i);
if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) { if(numOfElemsToReverse == 0 || numOfElemsToReverse == 1) {
outSubArrsSet->at(i)->assign(inSubArrsSet->at(i)); outSubArrsSet->at(i)->assign(inSubArrsSet->at(i));
} }
@ -172,7 +172,7 @@ static void _reverseSequence(nd4j::LaunchContext * context, const NDArray* input
auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim}); auto outInnerSet = outSubArrsSet->at(i)->allTensorsAlongDimension({seqDim});
for(int j = 0; j < inInnerSet->size(); ++j) for(int j = 0; j < inInnerSet->size(); ++j)
helpers::reverseArray<T>(context, inInnerSet->at(j)->getBuffer(), inInnerSet->at(j)->getShapeInfo(), outInnerSet->at(j)->getBuffer(), outInnerSet->at(j)->getShapeInfo(), numOfElemsToReverse); helpers::reverseArray<T>(context, inInnerSet->at(j)->getBuffer(), inInnerSet->at(j)->getShapeInfo(), outInnerSet->at(j)->getBuffer(), outInnerSet->at(j)->getShapeInfo(), numOfElemsToReverse);
delete inInnerSet; delete inInnerSet;
delete outInnerSet; delete outInnerSet;
} }
@ -195,12 +195,12 @@ void reverse(nd4j::LaunchContext * context, const NDArray* input, NDArray* outpu
auto listOut = output->allTensorsAlongDimension(dimensions); auto listOut = output->allTensorsAlongDimension(dimensions);
auto listIn = input->allTensorsAlongDimension(dimensions); auto listIn = input->allTensorsAlongDimension(dimensions);
NDArray *subArrIn, *subArrOut; NDArray *subArrIn, *subArrOut;
for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size() for(int i = 0; i < listIn->size(); ++i) { // listIn->size() = listOut->size()
subArrIn = listIn->at(i); subArrIn = listIn->at(i);
subArrOut = listOut->at(i); subArrOut = listOut->at(i);
BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->getBuffer(), subArrIn->getShapeInfo(), subArrOut->getBuffer(), subArrOut->getShapeInfo()), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), helpers::reverseArray, (context, subArrIn->getBuffer(), subArrIn->getShapeInfo(), subArrOut->getBuffer(), subArrOut->getShapeInfo()), LIBND4J_TYPES);
} }


@ -116,15 +116,15 @@ static void batchToSpaceND_(const NDArray& input, const NDArray& crop, NDArray&
for (Nd4jLong i = 0; i < zLen; ++i) { for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(rank, output.shapeOf(), i, zLen, coords.data()); shape::index2coords(i, output.getShapeInfo(), coords.data());
const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), coords.data(), rank); const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
// evaluate spatial coordinates for x // evaluate spatial coordinates for x
for(uint j = 1; j <= numOfSpatialDims; ++j) for(uint j = 1; j <= numOfSpatialDims; ++j)
coords[j] += crop.e<uint>(j - 1, 0); // add crop left coords[j] += crop.e<uint>(j - 1, 0); // add crop left
z[zOffset] = x[shape::getOffset(0, input.shapeOf(), input.stridesOf(), coords.data(), rank)]; z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
} }
} }
@ -298,9 +298,9 @@ static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArra
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords)) PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords))
for (Nd4jLong i = 0; i < zLen; ++i) { for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(rank, output.shapeOf(), i, zLen, coords.data()); shape::index2coords(i, output.getShapeInfo(), coords.data());
const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), coords.data(), rank); const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
bool within = true; bool within = true;
@ -318,7 +318,7 @@ static void spaceToBatchND_(const NDArray& input, const NDArray& padding, NDArra
} }
if(within) if(within)
z[zOffset] = x[shape::getOffset(0, input.shapeOf(), input.stridesOf(), coords.data(), rank)]; z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
else else
z[zOffset] = 0.f; z[zOffset] = 0.f;
} }


@ -178,8 +178,6 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
const Nd4jLong* xShape = input.shapeOf(); const Nd4jLong* xShape = input.shapeOf();
const Nd4jLong* zShape = output.shapeOf(); const Nd4jLong* zShape = output.shapeOf();
const Nd4jLong* xStride = input.stridesOf();
const Nd4jLong* zStride = output.stridesOf();
const int rank = input.rankOf(); // both input and output have the same rank const int rank = input.rankOf(); // both input and output have the same rank
const int rankMinusOne = rank - 1; const int rankMinusOne = rank - 1;
@ -195,8 +193,8 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords)) PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords))
for(uint i = 0; i < zLen; ++i) { for(uint i = 0; i < zLen; ++i) {
shape::index2coords(rank, zShape, i, zLen, coords.data()); shape::index2coords(i, output.getShapeInfo(), coords.data());
const auto zOffset = shape::getOffset(0, zShape, zStride, coords.data(), rank); const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
bool within = true; bool within = true;
for(int j = rankMinusOne; j >= 0; --j) { for(int j = rankMinusOne; j >= 0; --j) {
@ -207,7 +205,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
} }
if(within) if(within)
z[zOffset] = x[shape::getOffset(0, xShape, xStride, coords.data(), rank)]; z[zOffset] = x[shape::getOffset(input.getShapeInfo(), coords.data())];
else else
z[zOffset] = padVal; z[zOffset] = padVal;
} }
@ -220,8 +218,8 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords)) PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(coords))
for(uint i = 0; i < zLen; ++i) { for(uint i = 0; i < zLen; ++i) {
shape::index2coords(rank, zShape, i, zLen, coords.data()); shape::index2coords(i, output.getShapeInfo(), coords.data());
const auto zOffset = shape::getOffset(0, zShape, zStride, coords.data(), rank); const auto zOffset = shape::getOffset(output.getShapeInfo(), coords.data());
for(int j = rankMinusOne; j >= 0; --j) { for(int j = rankMinusOne; j >= 0; --j) {
@ -231,7 +229,7 @@ void pad_(const int mode, const NDArray& input, const NDArray& paddings, NDArray
else if(coords[j] >= xShape[j]) coords[j] = 2 * xShape[j] - coords[j] - shift2; // means fill from right else if(coords[j] >= xShape[j]) coords[j] = 2 * xShape[j] - coords[j] - shift2; // means fill from right
} }
const auto xOffset = shape::getOffset(0, xShape, xStride, coords.data(), rank); const auto xOffset = shape::getOffset(input.getShapeInfo(), coords.data());
z[zOffset] = x[xOffset]; z[zOffset] = x[xOffset];
} }
} }
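The reflection logic is easiest to read in one dimension: for an axis of length n, an out-of-range coordinate p is folded back inside, with shift1/shift2 selecting the variant (a reading of the code above, not quoted documentation):

// REFLECT   (edge not repeated): p < 0 -> -p     ;  p >= n -> 2n - p - 2   (shift1 = 0, shift2 = 2)
// SYMMETRIC (edge repeated)    : p < 0 -> -p - 1 ;  p >= n -> 2n - p - 1   (shift1 = 1, shift2 = 1)
// e.g. n = 4, p = -1: REFLECT maps to 1, SYMMETRIC maps to 0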
@ -580,9 +578,9 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
xCoordStart = coords.data(); xCoordStart = coords.data();
} }
shape::index2coords(zRank, output.shapeOf(), i, zLen, zCoordStart); shape::index2coords(i, output.getShapeInfo(), zCoordStart);
const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), zCoordStart, zRank); const auto zOffset = shape::getOffset(output.getShapeInfo(), zCoordStart);
// last y coordinate // last y coordinate
uint coordToRestore; uint coordToRestore;
@ -590,7 +588,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
coordToRestore = static_cast<uint>(zCoordStart[yRank - 1]); coordToRestore = static_cast<uint>(zCoordStart[yRank - 1]);
zCoordStart[yRank - 1] = 0; zCoordStart[yRank - 1] = 0;
const auto yOffset = shape::getOffset(0, indices.shapeOf(), indices.stridesOf(), zCoordStart, yRank); const auto yOffset = shape::getOffset(indices.getShapeInfo(), zCoordStart);
//restore z coordinate //restore z coordinate
if(yLastDim != xRank) if(yLastDim != xRank)
@ -600,7 +598,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {
for(uint j = 0; j < yLastDim; ++j) for(uint j = 0; j < yLastDim; ++j)
xCoordStart[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride xCoordStart[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]]; // last stride
const auto xOffset = shape::getOffset(0, input.shapeOf(), input.stridesOf(), xCoordStart, xRank); const auto xOffset = shape::getOffset(input.getShapeInfo(), xCoordStart);
z[zOffset] = x[xOffset]; z[zOffset] = x[xOffset];
} }
@ -1172,7 +1170,7 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(inIdx, outIdx)) PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(inIdx, outIdx))
for(int i = 0; i < outLen; ++i) { for(int i = 0; i < outLen; ++i) {
shape::index2coords(rank, output.shapeOf(), i, outIdx.data()); shape::index2coords(i, output.getShapeInfo(), outIdx.data());
for(int j = 0; j < rank; ++j) { for(int j = 0; j < rank; ++j) {
@ -1191,8 +1189,8 @@ static void mirrorPad_(const NDArray& input, const NDArray& paddings, NDArray& o
inIdx[j] = len - outIdx[j]; inIdx[j] = len - outIdx[j];
} }
auto outOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), outIdx.data(), rank); auto outOffset = shape::getOffset(output.getShapeInfo(), outIdx.data());
auto inOffset = shape::getOffset(0, input.shapeOf(), input.stridesOf(), inIdx.data(), rank); auto inOffset = shape::getOffset(input.getShapeInfo(), inIdx.data());
reinterpret_cast<T*>(output.buffer())[outOffset] = reinterpret_cast<T*>(input.getBuffer())[inOffset]; reinterpret_cast<T*>(output.buffer())[outOffset] = reinterpret_cast<T*>(input.getBuffer())[inOffset];
} }
} }
@ -1259,7 +1257,7 @@ static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, c
for(Nd4jLong i=0; i<gradOLen; ++i) { for(Nd4jLong i=0; i<gradOLen; ++i) {
auto fidx = shape::subArrayIndex(i, gradO.getShapeInfo(), gradI.getShapeInfo()); auto fidx = shape::subArrayIndex(i, gradO.getShapeInfo(), gradI.getShapeInfo());
gradI.p(fidx, gradI.e<T>(fidx) + gradOBuff[shape::getIndexOffset(i, gradO.getShapeInfo(), gradOLen)]); gradI.p(fidx, gradI.e<T>(fidx) + gradOBuff[shape::getIndexOffset(i, gradO.getShapeInfo())]);
} }
} }
} }


@ -60,9 +60,9 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
for (int i = tid; i < xzLen; i += totalThreads) { for (int i = tid; i < xzLen; i += totalThreads) {
shape::index2coords(xzRank, xShapeInfo + 1, i, xzLen, coords); shape::index2coords(i, xShapeInfo, coords);
const auto xzOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + xzRank + 1, coords, xzRank); const auto xzOffset = shape::getOffset(xShapeInfo, coords);
const auto xVal = x[xzOffset]; const auto xVal = x[xzOffset];
@ -72,7 +72,7 @@ __global__ void preluCuda(const void *vx, const Nd4jLong *xShapeInfo,
if(yShapeInfo[j + 1] == 1) if(yShapeInfo[j + 1] == 1)
coords[j + 1] = 0; coords[j + 1] = 0;
z[xzOffset] = xVal * y[shape::getOffset(0, yShapeInfo + 1, yShapeInfo + yRank + 1, coords + 1, yRank)]; z[xzOffset] = xVal * y[shape::getOffset(yShapeInfo, coords + 1)];
} }
else else
z[xzOffset] = xVal; z[xzOffset] = xVal;
@ -139,11 +139,11 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
for (int i = tid; i < inLen; i += totalThreads) { for (int i = tid; i < inLen; i += totalThreads) {
shape::index2coords(inRank, inShapeInfo + 1, i, inLen, coords); shape::index2coords(i, inShapeInfo, coords);
const auto inOffset = shape::getOffset(0, inShapeInfo + 1, inShapeInfo + inRank + 1, coords, inRank); const auto inOffset = shape::getOffset(inShapeInfo, coords);
const auto dLdOOffset = shape::getOffset(0, dLdOShapeInfo + 1, dLdOShapeInfo + inRank + 1, coords, inRank); const auto dLdOOffset = shape::getOffset(dLdOShapeInfo, coords);
const auto dLdIOffset = shape::getOffset(0, dLdIShapeInfo + 1, dLdIShapeInfo + inRank + 1, coords, inRank); const auto dLdIOffset = shape::getOffset(dLdIShapeInfo, coords);
const auto xVal = in[inOffset]; const auto xVal = in[inOffset];
const auto grO = dLdO[dLdOOffset]; const auto grO = dLdO[dLdOOffset];
@ -154,8 +154,8 @@ __global__ linkage void preluBPCuda(const void *vIn, const Nd4jLong *inShapeI
if(alphaShapeInfo[j + 1] == 1) if(alphaShapeInfo[j + 1] == 1)
coords[j + 1] = 0; coords[j + 1] = 0;
const auto alphaOffset = shape::getOffset(0, alphaShapeInfo + 1, alphaShapeInfo + alphaRank + 1, coords + 1, alphaRank); const auto alphaOffset = shape::getOffset(alphaShapeInfo, coords + 1);
const auto dLdAOffset = shape::getOffset(0, dLdAShapeInfo + 1, dLdAShapeInfo + alphaRank + 1, coords + 1, alphaRank); const auto dLdAOffset = shape::getOffset(dLdAShapeInfo, coords + 1);
dLdI[dLdIOffset] = grO * alpha[alphaOffset]; dLdI[dLdIOffset] = grO * alpha[alphaOffset];
@ -223,7 +223,7 @@ __device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo,
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len); const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? x[xOffset] : nd4j::math::nd4j_max<T>(x[xOffset], temp); // take into account max element evaluated on previous iteration and stored in temp
} }
else else
@ -249,8 +249,8 @@ __device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo,
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo, len); const Nd4jLong xOffset = shape::getIndexOffset(elemIdx, xShapeInfo);
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len); const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo);
z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max); z[zOffset] = nd4j::math::nd4j_exp<T, T>(x[xOffset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? z[zOffset] : (z[zOffset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
} }
@ -272,7 +272,7 @@ __device__ void softMaxForVectorCuda(const void *vx, const Nd4jLong *xShapeInfo,
for (int i = 0; i < numOfIters; ++i) { for (int i = 0; i < numOfIters; ++i) {
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx >= len) continue; if(elemIdx >= len) continue;
const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo, len); const Nd4jLong zOffset = shape::getIndexOffset(elemIdx, zShapeInfo);
z[zOffset] /= shmem[0]; z[zOffset] /= shmem[0];
} }
} }
@ -386,7 +386,7 @@ __global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShape
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp
} }
else else
@ -412,7 +412,7 @@ __global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShape
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max); z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
} }
@ -434,7 +434,7 @@ __global__ void logSoftMaxForVectorCuda(const void *vx, const Nd4jLong *xzShape
for (int i = 0; i < numOfIters; ++i) { for (int i = 0; i < numOfIters; ++i) {
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx >= len) continue; if(elemIdx >= len) continue;
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
z[offset] = nd4j::math::nd4j_log<T,T>(z[offset] / shmem[0]); z[offset] = nd4j::math::nd4j_log<T,T>(z[offset] / shmem[0]);
} }
} }
@ -505,7 +505,7 @@ __global__ linkage void softMaxDerivForVectorCuda(const void *vx, const Nd4jLong
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? x[offset] : nd4j::math::nd4j_max<T>(x[offset], temp); // take into account max element evaluated on previous iteration and stored in temp
} }
else else
@ -531,7 +531,7 @@ __global__ linkage void softMaxDerivForVectorCuda(const void *vx, const Nd4jLong
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx < len) { if(elemIdx < len) {
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max); z[offset] = nd4j::math::nd4j_exp<T, T>(x[offset] - max);
shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp shmem[threadIdx.x] = (threadIdx.x != 0) ? z[offset] : (z[offset] + temp); // take into account sum element evaluated on previous iteration and stored in temp
} }
@ -553,7 +553,7 @@ __global__ linkage void softMaxDerivForVectorCuda(const void *vx, const Nd4jLong
for (int i = 0; i < numOfIters; ++i) { for (int i = 0; i < numOfIters; ++i) {
const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x; const Nd4jLong elemIdx = i * blockDim.x + threadIdx.x;
if(elemIdx >= len) continue; if(elemIdx >= len) continue;
const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo, len); const Nd4jLong offset = shape::getIndexOffset(elemIdx, xzShapeInfo);
z[offset] /= shmem[0]; z[offset] /= shmem[0];
z[offset] *= (1.f - z[offset]); // derivative z[offset] *= (1.f - z[offset]); // derivative
} }


@ -0,0 +1,110 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include<ops/declarable/helpers/addBias.h>
#include <PointersManager.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__global__ static void addBiasCuda( const void* vx, const Nd4jLong* xShapeInfo,
const void* vy, const Nd4jLong* yShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const bool isNCHW) {
// bias [oC]
// if(input_rank == 4)
// input and output have same shapes: [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
// if(input_rank == 5)
// input and output have same shapes: [bS, oD, oH, oW, oC] (NHWC) or [bS, oC, oD, oH, oW] (NCHW)
const X* x = reinterpret_cast<const X*>(vx);
const Y* y = reinterpret_cast<const Y*>(vy);
X* z = reinterpret_cast<X*>(vz);
__shared__ int rank, channelPosition;
__shared__ Nd4jLong *sharedMem, len;
__shared__ bool xzSameOffsets, xzAreSame;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
rank = shape::rank(xShapeInfo); // xRank == zRank
xzSameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
len = shape::length(xShapeInfo);
channelPosition = isNCHW ? 1 : rank - 1; // second or last
xzAreSame = x == z;
}
__syncthreads();
auto coords = sharedMem + threadIdx.x * rank;
for (Nd4jLong i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += blockDim.x * gridDim.x) {
shape::index2coords(i, xShapeInfo, coords);
const auto xOffsets = shape::getOffset(xShapeInfo, coords);
const auto zOffsets = xzSameOffsets ? xOffsets : shape::getOffset(zShapeInfo, coords);
const auto yOffsets = shape::getOffset(yShapeInfo, coords + channelPosition);
if(xzAreSame)
z[zOffsets] += static_cast<X>(y[yOffsets]);
else
z[zOffsets] = x[xOffsets] + static_cast<X>(y[yOffsets]);
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
static void addBiasCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo,
const void* vy, const Nd4jLong* yShapeInfo,
void* vz, const Nd4jLong* zShapeInfo,
const bool isNCHW) {
addBiasCuda<X,Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, isNCHW);
}
//////////////////////////////////////////////////////////////////////////
void addBias(nd4j::graph::Context& block, const NDArray& input, const NDArray& bias, NDArray& output, const bool isNCHW) {
PointersManager manager(block.launchContext(), "addBias");
const int threadsPerBlock = MAX_NUM_THREADS;
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = input.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
NDArray::prepareSpecialUse({&output}, {&input, &bias});
BUILD_DOUBLE_SELECTOR(input.dataType(), bias.dataType(), addBiasCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, block.launchContext()->getCudaStream(), input.getSpecialBuffer(), input.getSpecialShapeInfo(), bias.getSpecialBuffer(), bias.getSpecialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), isNCHW), FLOAT_TYPES, FLOAT_TYPES);
NDArray::registerSpecialUse({&output}, {&input, &bias});
manager.synchronize();
}
}
}
}
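Two launch-side details of the new kernel are worth noting: every thread keeps its own rank-sized coordinate array in shared memory, so sharedMem grows with threadsPerBlock (the extra 128 bytes read as alignment slack), and the bias offset is taken from the single coordinate at channelPosition, which is why yShapeInfo is indexed with coords + channelPosition. As a budget check:

// shared memory per block, as configured in addBias():
//   rank * sizeof(Nd4jLong) * threadsPerBlock + 128
// e.g. rank = 4, 512 threads: 4 * 8 * 512 + 128 = 16512 bytes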


@ -143,13 +143,13 @@ static void _CUDA_G adjustHueSingleNCHWKernel(void *xBuffer, Nd4jLong *xTadShape
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) { for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo);
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo);
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo);
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo);
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo);
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength);; auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo);
T h, v_min, v_max; T h, v_min, v_max;
helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max); helpers::rgb_to_hv(_ri[0], _gi[0], _bi[0], &h, &v_min, &v_max);


@ -139,13 +139,13 @@ static void _CUDA_G adjustSaturationSingleNCHWKernel(void *xBuffer, Nd4jLong *xT
auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2]; auto outputB = reinterpret_cast<T *>(zBuffer) + zOffsets[2];
for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) { for (Nd4jLong e = tid; e < tuples; e += blockDim.x * gridDim.x) {
auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _ri = bufferR + shape::getIndexOffset(e, xTadShapeInfo);
auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _gi = bufferG + shape::getIndexOffset(e, xTadShapeInfo);
auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _bi = bufferB + shape::getIndexOffset(e, xTadShapeInfo);
auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _ro = outputR + shape::getIndexOffset(e, xTadShapeInfo);
auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _go = outputG + shape::getIndexOffset(e, xTadShapeInfo);
auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo, tadLength); auto _bo = outputB + shape::getIndexOffset(e, xTadShapeInfo);
T h, s, v; T h, s, v;
// Convert the RGB color to Hue/V-range. // Convert the RGB color to Hue/V-range.


@ -64,25 +64,25 @@ __global__ static void batchnormCuda(const void* vx, const Nd4jLong* xShapeInfo,
for (uint i = tid; i < minLen; i += totalThreads) { for (uint i = tid; i < minLen; i += totalThreads) {
const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo, minLen); const auto meanOffset = shape::getIndexOffset(i, meanShapeInfo);
const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo, minLen); const auto varianceOffset = shape::getIndexOffset(i, varianceShapeInfo);
T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt<T, T>(variance[varianceOffset] + epsilon); T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt<T, T>(variance[varianceOffset] + epsilon);
if(gamma != nullptr) if(gamma != nullptr)
sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo, minLen)]; sigmaInvGam *= gamma[shape::getIndexOffset(i, gammaShapeInfo)];
auto betaOffset = 0; auto betaOffset = 0;
if(beta != nullptr) if(beta != nullptr)
betaOffset = shape::getIndexOffset(i, betaShapeInfo, minLen); betaOffset = shape::getIndexOffset(i, betaShapeInfo);
const auto xTad = x + xTadOffsets[i]; const auto xTad = x + xTadOffsets[i];
auto zTad = z + zTadOffsets[i]; auto zTad = z + zTadOffsets[i];
for (uint j = 0; j < tadLen; ++j) { for (uint j = 0; j < tadLen; ++j) {
const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo, tadLen); const auto xTadOffset = shape::getIndexOffset(j, xTadShapeInfo);
const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo, tadLen); const auto zTadOffset = shape::getIndexOffset(j, zTadShapeInfo);
zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam; zTad[zTadOffset] = (xTad[xTadOffset] - mean[meanOffset]) * sigmaInvGam;
@ -130,10 +130,10 @@ __global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo
for (uint i = tid; i < xLen; i += totalThreads) { for (uint i = tid; i < xLen; i += totalThreads) {
shape::index2coords(xRank, shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo)), i, xLen, coords); shape::index2coords(i, xShapeInfo, coords);
const auto xOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(xShapeInfo)), shape::stride(const_cast<Nd4jLong*>(xShapeInfo)), coords, xRank); const auto xOffset = shape::getOffset(xShapeInfo, coords);
const auto zOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(zShapeInfo)), shape::stride(const_cast<Nd4jLong*>(zShapeInfo)), coords, xRank); const auto zOffset = shape::getOffset(zShapeInfo, coords);
if(minRank == xRank) { if(minRank == xRank) {
for (uint i = 0, j = 0; i < xRank; ++i) { for (uint i = 0, j = 0; i < xRank; ++i) {
@@ -146,20 +146,20 @@ __global__ static void batchnormCuda2(const void* vx, const Nd4jLong* xShapeInfo
         else // minRank = numDims = 1 in this case
             coords[0] = coords[dims[0]];
-        const auto meanOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(meanShapeInfo)), shape::stride(const_cast<Nd4jLong*>(meanShapeInfo)), coords, minRank);
-        const auto varianceOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(varianceShapeInfo)), shape::stride(const_cast<Nd4jLong*>(varianceShapeInfo)), coords, minRank);
+        const auto meanOffset = shape::getOffset(meanShapeInfo, coords);
+        const auto varianceOffset = shape::getOffset(varianceShapeInfo, coords);
         T sigmaInvGam = 1. / nd4j::math::nd4j_sqrt<T, T>(variance[varianceOffset] + epsilon);
         if(gamma != nullptr) {
-            const auto gammaOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(gammaShapeInfo)), shape::stride(const_cast<Nd4jLong*>(gammaShapeInfo)), coords, minRank);
+            const auto gammaOffset = shape::getOffset(gammaShapeInfo, coords);
             sigmaInvGam *= gamma[gammaOffset];
         }
         z[zOffset] = (x[xOffset] - mean[meanOffset]) * sigmaInvGam;
         if(beta != nullptr) {
-            const auto betaOffset = shape::getOffset(0, shape::shapeOf(const_cast<Nd4jLong*>(betaShapeInfo)), shape::stride(const_cast<Nd4jLong*>(betaShapeInfo)), coords, minRank);
+            const auto betaOffset = shape::getOffset(betaShapeInfo, coords);
             z[zOffset] += beta[betaOffset];
         }
     }
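batchnormCuda2 shows the companion change for coordinate-based indexing: shape::index2coords and shape::getOffset now take the full shapeInfo instead of separately extracted shape/stride pointers plus an explicit rank, which removes the const_cast and shape::shapeOf/shape::stride noise at every call site. A sketch of the pair, under the same toy shapeInfo layout and Nd4jLong alias as the sketch above (names ours, 'c' order only):

    // Linear index -> coordinates.
    static void index2coordsSketch(Nd4jLong i, const Nd4jLong* shapeInfo, Nd4jLong* coords) {
        const Nd4jLong rank   = shapeInfo[0];
        const Nd4jLong* shape = shapeInfo + 1;
        for (Nd4jLong d = rank - 1; d >= 0; --d) {
            coords[d] = i % shape[d];
            i /= shape[d];
        }
    }

    // Coordinates -> memory offset: a plain dot product with the strides.
    static Nd4jLong getOffsetSketch(const Nd4jLong* shapeInfo, const Nd4jLong* coords) {
        const Nd4jLong rank    = shapeInfo[0];
        const Nd4jLong* stride = shapeInfo + 1 + rank;
        Nd4jLong offset = 0;
        for (Nd4jLong d = 0; d < rank; ++d)
            offset += coords[d] * stride[d];
        return offset;
    }

Splitting the two steps pays off in this kernel: the coordinates are computed once per element and then reused against xShapeInfo, zShapeInfo, meanShapeInfo and the rest, which share the relevant shape but differ in strides.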


@@ -15,7 +15,7 @@
 ******************************************************************************/
 //
-// Created by Yurii Shyrma on 11.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 #include<cmath>
@@ -117,10 +117,10 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
     Nd4jLong len = shape::length(xShapeInfo);
-    const T a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo, len));
-    const T b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo, len));
-    const T x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo, len));
-    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo, len));
+    const T a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
+    const T b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
+    const T x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
+    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));
     // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
     if(a == b && x == static_cast<T>(0.5)) {
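The short-circuit at the bottom of this hunk rests on a symmetry of the regularized incomplete beta function:

    I_x(a, b) = 1 - I_{1-x}(b, a)

so for a = b and x = 1/2 the identity gives I_{1/2}(a, a) = 1 - I_{1/2}(a, a), forcing I_{1/2}(a, a) = 1/2 exactly, and the kernel can return 0.5 without running the continued-fraction evaluation.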


@@ -35,12 +35,12 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp
     auto colShape = shape::shapeOf(colShapeBuffer);
     auto colStride = shape::stride(colShapeBuffer);
     auto imShape = shape::shapeOf(imShapeBuffer);
     auto imStride = shape::stride(imShapeBuffer);
     const int bS = imShape[0];
     const int iC = imShape[1];
     const int kH = colShape[2];
     const int kW = colShape[3];
     const int oH = colShape[4];
     const int oW = colShape[5];
     const Nd4jLong colStride0 = colStride[0];
@@ -58,31 +58,31 @@ void col2im_(nd4j::LaunchContext & context, const NDArray& input, NDArray& outp
     const auto imEWS = shape::elementWiseStride(imShapeBuffer);
     if(imEWS == 1) {
         memset(imBuff, 0, shape::length(imShapeBuffer) * sizeof(T));
     }
     else if (imEWS > 1) {
         PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close))
         for (int i = 0; i < shape::length(imShapeBuffer) * imEWS; i += imEWS)
             imBuff[i] = static_cast<T>(0.f);
     }
     else {
         const auto len = shape::length(imShapeBuffer);
         PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close))
         for (int i = 0; i < len; i++)
-            imBuff[shape::getIndexOffset(i, imShapeBuffer, len)] = static_cast<T>(0.f);
+            imBuff[shape::getIndexOffset(i, imShapeBuffer)] = static_cast<T>(0.f);
     }
     T *col, *im;
     int imRow, imCol;
     if (shape::order(colShapeBuffer) == 'c' && shape::order(imShapeBuffer) == 'c' && shape::strideDescendingCAscendingF(colShapeBuffer) && shape::strideDescendingCAscendingF(imShapeBuffer)) {
         PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close) private(col, im, imRow, imCol))
         for (int b = 0; b < bS; b++) {
             for (int c = 0; c < iC; ++c) {
                 for (int kRow = 0; kRow < kH; ++kRow) {
                     for (int kCol = 0; kCol < kW; ++kCol) {
                         for (int colH = 0; colH < oH; ++colH) {
                             for (int colW = 0; colW < oW; ++colW) {
                                 imRow = (-pH + kRow * dH) + colH*sH;
                                 imCol = (-pW + kCol * dW) + colW*sW;
@@ -97,21 +97,21 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close) private(col, im,
                             }
                         }
                     }
                 }
             }
     }
     else {
         PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close) private(im, col, imRow, imCol))
         for (int b = 0; b < bS; b++) {
             for (int colH = 0; colH < oH; ++colH) {
                 for (int colW = 0; colW < oW; ++colW) {
                     for (int c = 0; c < iC; ++c) {
                         for (int kRow = 0; kRow < kH; ++kRow) {
                             for (int kCol = 0; kCol < kW; ++kCol) {
                                 imRow = (-pH + kRow * dH) + colH*sH;
                                 imCol = (-pW + kCol * dW) + colW*sW;
                                 col = colBuff + b*colStride0 + c*colStride1 + kRow*colStride2 + kCol*colStride3 + colH*colStride4 + colW*colStride5;
                                 im = imBuff + b*imStride0 + c*imStride1 + imRow*imStride2 + imCol*imStride3;
@@ -120,9 +120,9 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(static) proc_bind(close) private(im, col,
                                 }
                             }
                         }
                     }
                 }
             }
         }
     }
}
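The zero-initialisation at the top of col2im_ picks one of three strategies based on the output buffer's element-wise stride (EWS): a single memset when the buffer is fully contiguous, a fixed-step loop when elements are evenly spaced, and a per-index offset lookup otherwise. A compact sketch of that dispatch, reusing the toy getIndexOffsetSketch from the first sketch above (function names are ours, not the library's):

    #include <cstring>
    #include <cstdint>

    using Nd4jLong = int64_t;

    Nd4jLong getIndexOffsetSketch(Nd4jLong i, const Nd4jLong* shapeInfo); // from the first sketch

    template <typename T>
    void zeroFillSketch(T* buf, Nd4jLong len, Nd4jLong ews, const Nd4jLong* shapeInfo) {
        if (ews == 1) {
            std::memset(buf, 0, len * sizeof(T));            // contiguous: one memset
        }
        else if (ews > 1) {
            for (Nd4jLong i = 0; i < len * ews; i += ews)    // constant gap between elements
                buf[i] = static_cast<T>(0.f);
        }
        else {                                               // no linear stride: resolve each offset
            for (Nd4jLong i = 0; i < len; i++)
                buf[getIndexOffsetSketch(i, shapeInfo)] = static_cast<T>(0.f);
        }
    }

The memset branch is only valid because an all-zero bit pattern represents 0 for the numeric types used here; the other two branches write typed zeros through the appropriate stride.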
