Shyrma broadcast (#302)

* - profiling TrueBroadcastHelper

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further improvements to TrueBroadcastHelper

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further profiling of the broadcast op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation of broadcastShapeHelper, which inserts unities into the shapes of arrays to be broadcast (a standalone sketch of the idea follows the commit metadata below)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide an additional method in the ConstantShapeHelper class for deducing broadcast shapes with unities

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide new NativeOps helpers for the usual and true broadcast methods

Signed-off-by: Yurii <iuriish@yahoo.com>

* enable BERT profiler

Signed-off-by: raver119 <raver119@gmail.com>

* - delete unnecessary tests

Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119 <raver119@gmail.com>
Branch: master
Author: Yurii Shyrma, 2020-03-10 15:29:09 +02:00 (committed via GitHub)
Commit: 6aaca58506, parent: c3223dbc7a
35 changed files with 1487 additions and 1247 deletions
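Before the per-file diffs, here is a minimal standalone sketch of the core idea named in the commit message (plain C++17, illustrative only; alignForBroadcast is a hypothetical name, not the libnd4j API). The smaller operand's shape is right-aligned against the broadcast target shape, missing leading axes become extent-1 "unities", and every unity axis gets stride 0, so a single flat broadcast kernel can walk both operands at equal rank:

// Standalone sketch of the "insert unities" alignment.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Aligned {
    std::vector<int64_t> shape;
    std::vector<int64_t> stride;
};

// maxShape is the broadcast (target) shape; minShape/minStride describe the
// smaller operand. Missing leading axes become extent-1 axes with stride 0.
static Aligned alignForBroadcast(const std::vector<int64_t>& maxShape,
                                 const std::vector<int64_t>& minShape,
                                 const std::vector<int64_t>& minStride) {
    const int maxRank = static_cast<int>(maxShape.size());
    Aligned out{std::vector<int64_t>(maxRank), std::vector<int64_t>(maxRank)};
    for (int i = maxRank - 1, j = static_cast<int>(minShape.size()) - 1; i >= 0; --i, --j) {
        if (j >= 0) {
            out.shape[i]  = minShape[j];
            out.stride[i] = (minShape[j] == 1) ? 0 : minStride[j]; // unity axis -> stride 0
        } else {
            out.shape[i]  = 1;  // inserted unity
            out.stride[i] = 0;
        }
    }
    return out;
}

int main() {
    // y of shape {3,1} (c-order strides {1,1}) broadcast against target {2,3,4}
    const Aligned a = alignForBroadcast({2, 3, 4}, {3, 1}, {1, 1});
    for (std::size_t i = 0; i < a.shape.size(); ++i)
        std::printf("axis %zu: extent %lld stride %lld\n", i,
                    static_cast<long long>(a.shape[i]), static_cast<long long>(a.stride[i]));
    // prints extents {1,3,1} with strides {0,1,0}
    return 0;
}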

==== changed file ====

@@ -26,7 +26,6 @@
#include <helpers/ConstantTadHelper.h>
#include <loops/BroadcastPairwiseConverter.h>
#include <helpers/PointersManager.h>
#include <loops/TrueBroadcastHelper.h>
namespace sd {
@@ -2801,15 +2800,15 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other,
if (isEmpty() || other.isEmpty())
return;
if (lengthOf() == 1) {
target.assign(this);
target.applyPairwiseTransform(op.p, other, extraArgs);
return;
}
if (other.lengthOf() == 1) {
const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
return;
}
// if (lengthOf() == 1) {
// target.assign(this);
// target.applyPairwiseTransform(op.p, other, extraArgs);
// return;
// }
// if (other.lengthOf() == 1) {
// const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
// return;
// }
if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr;
@@ -2819,36 +2818,46 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other,
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !");
}
if(target.isSameShape(this) || target.isSameShape(other)) {
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs);
return;
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace());
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace());
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(dataType(), other.dataType(), target.dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, LIBND4J_TYPES);
#else
BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES);
#endif
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execBroadcast(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
registerSpecialUse({&target}, {this, &other});
}
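With the length-1 shortcuts commented out above, every call now funnels through the single execBroadcast path. A hedged usage sketch (shapes are illustrative; NDArrayFactory and BroadcastOpsTuple are used the way this codebase's tests use them):

// x {2,3} op y {3}: y is not same-shape as the target, so its shape info is
// rewritten as {1,3} with stride 0 on the unity axis, then one
// NativeOpExecutioner::execBroadcast call does the work.
auto x = NDArrayFactory::create<float>('c', {2, 3});
auto y = NDArrayFactory::create<float>('c', {3});
auto z = NDArrayFactory::create<float>('c', {2, 3});
x.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), y, z);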
//////////////////////////////////////////////////////////////////////////
void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
if (isS())
throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!");
if (isEmpty() || other.isEmpty())
return;
if (lengthOf() == 1) {
NDArray temp(target._shapeInfo, dataType(), false, getContext());
temp.assign(this);
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
return;
}
if (other.lengthOf() == 1) {
this->applyScalarArr(op.s, other, target, extraArgs);
return;
}
// if (lengthOf() == 1) {
// NDArray temp(target._shapeInfo, dataType(), false, getContext());
// temp.assign(this);
// temp.applyPairwiseTransform(op.p, other, target, extraArgs);
// return;
// }
// if (other.lengthOf() == 1) {
// this->applyScalarArr(op.s, other, target, extraArgs);
// return;
// }
if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr;
@@ -2860,12 +2869,25 @@ void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& ot
throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !");
}
if(target.isSameShape(this) || target.isSameShape(other)) {
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs);
return;
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace());
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace());
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, BOOL_TYPES);
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execBroadcastBool(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
registerSpecialUse({&target}, {this, &other});
}
//////////////////////////////////////////////////////////////////////////
@@ -2877,16 +2899,16 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& oth
if (isEmpty() || other.isEmpty())
return;
if (lengthOf() == 1) {
NDArray temp(target._shapeInfo, dataType(), false, getContext());
temp.assign(this);
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
return;
}
if (other.lengthOf() == 1) {
this->applyScalarArr(op.s, other, target, extraArgs);
return;
}
// if (lengthOf() == 1) {
// NDArray temp(target._shapeInfo, dataType(), false, getContext());
// temp.assign(this);
// temp.applyPairwiseTransform(op.p, other, target, extraArgs);
// return;
// }
// if (other.lengthOf() == 1) {
// this->applyScalarArr(op.s, other, target, extraArgs);
// return;
// }
if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr;
@@ -2898,12 +2920,25 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& oth
throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
}
if(target.isSameShape(this) || target.isSameShape(other)) {
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs);
return;
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace());
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace());
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, other, target), INTEGER_TYPES);
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execBroadcastInt(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
registerSpecialUse({&target}, {this, &other});
}
//////////////////////////////////////////////////////////////////////////
@@ -3008,6 +3043,10 @@ NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, E
//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
if (dimensions.size() == 0)
return;
if (isS())
throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!");
if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB()))
@@ -3018,53 +3057,50 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dime
return;
}
if (dimensions.size() == 0)
return;
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&target}, {this, &other});
return;
}
NDArray *min(nullptr), *max(nullptr);
if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) {
max = this;
min = const_cast<NDArray*>(&other);
}
else {
max = const_cast<NDArray*>(&other);
min = this;
}
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
// NDArray::prepareSpecialUse({&target}, {this, &other});
// NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
// NDArray::registerSpecialUse({&target}, {this, &other});
// return;
// }
if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.getShapeInfo()))
throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !");
if(!target.isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast method: max and target arrays must have the same shape !");
if(!target.isSameShape(this) && !target.isSameShape(other))
throw std::invalid_argument("NDArray::applyBroadcast method: one of of two input arrays (this or other) should has the same shape as target array!");
std::vector<int> copy(dimensions);
if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());
Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
if (tadLength != min->lengthOf())
throw std::runtime_error("NDArray::applyBroadcast method: tad length mismatch !");
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
NDArray::prepareSpecialUse({&target}, {this, &other});
if(max == this)
NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else
NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
registerSpecialUse({&target}, {this, &other});
}
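A hedged sketch of the dimensioned variant rewritten above (shapes illustrative): when y is applied along axis 1 of a rank-3 x, the helper pins y's only axis at dimension 1 and pads the remaining axes with unities, so y's shape info becomes {1,3,1} with strides {0,1,0}:

auto x = NDArrayFactory::create<float>('c', {2, 3, 4});
auto y = NDArrayFactory::create<float>('c', {3});
auto z = NDArrayFactory::create<float>('c', {2, 3, 4});
x.applyBroadcast(sd::broadcast::Add, {1}, y, z);   // one execBroadcast call, no TADs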
//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
if (dimensions.size() == 0)
return;
if (isS())
throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!");
if(isEmpty() || other.isEmpty()) {
@@ -3073,30 +3109,17 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
return;
}
if (dimensions.size() == 0)
return;
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&target}, {this, &other});
return;
}
NDArray *min(nullptr), *max(nullptr);
if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) {
max = this;
min = const_cast<NDArray*>(&other);
}
else {
max = const_cast<NDArray*>(&other);
min = this;
}
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
// NDArray::prepareSpecialUse({&target}, {this, &other});
// NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
// NDArray::registerSpecialUse({&target}, {this, &other});
// return;
// }
if(target.dataType() != DataType::BOOL)
throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!");
if(!target.isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast bool method: max and target arrays must have the same shape !");
if(!target.isSameShape(this) && !target.isSameShape(other))
throw std::invalid_argument("NDArray::applyBroadcast bool method: one of of two input arrays (this or other) should has the same shape as target array!");
if(_dataType != other._dataType)
throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !");
@@ -3105,25 +3128,34 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());
Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
if (tadLength != min->lengthOf())
throw std::runtime_error("Tad length mismatch");
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
// TODO: eventually we want separate tads here
NDArray::prepareSpecialUse({&target}, {this, &other});
if(max == this)
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
registerSpecialUse({&target}, {this, &other});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
if (dimensions.empty())
return;
if (!isZ())
throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
if(isEmpty() || other.isEmpty()) {
@@ -3132,30 +3164,17 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& d
return;
}
if (dimensions.empty())
return;
if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
NDArray::prepareSpecialUse({&target}, {this, &other});
NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({&target}, {this, &other});
return;
}
NDArray *min(nullptr), *max(nullptr);
if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) {
max = this;
min = const_cast<NDArray*>(&other);
}
else {
max = const_cast<NDArray*>(&other);
min = this;
}
// if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
// NDArray::prepareSpecialUse({&target}, {this, &other});
// NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
// NDArray::registerSpecialUse({&target}, {this, &other});
// return;
// }
if(target.dataType() != dataType())
throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!");
if(!target.isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !");
if(!target.isSameShape(this) && !target.isSameShape(other))
throw std::invalid_argument("NDArray::applyBroadcast int method: one of of two input arrays (this or other) should has the same shape as target array!");
if(_dataType != other._dataType)
throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");
@@ -3164,19 +3183,24 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& d
if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());
Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
if (tadLength != min->lengthOf())
throw std::runtime_error("Tad length mismatch");
Nd4jLong* xShapeInfoH = getShapeInfo();
Nd4jLong* yShapeInfoH = other.getShapeInfo();
Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
if(!isSameShape(target)) {
auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
}
if(!other.isSameShape(target)) {
auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
}
// TODO: eventually we want separate tads here
NDArray::prepareSpecialUse({&target}, {this, &other});
if(max == this)
NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else
NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
registerSpecialUse({&target}, {this, &other});
}

==== changed file ====

@@ -23,11 +23,11 @@
#include <cuda.h>
#include <cuda_runtime.h>
static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}
static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
@@ -94,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto zLength = length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -103,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL
z[e * zEws] = lambda(x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = __getIndexOffset(e, xShapeInfo);
auto zOffset = __getIndexOffset(e, zShapeInfo);
auto xOffset = getIndexOffset(e, xShapeInfo);
auto zOffset = getIndexOffset(e, zShapeInfo);
z[zOffset] = lambda(x[xOffset]);
}
@@ -123,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
auto xOrder = shape::order(xShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto zLength = length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -132,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz
z[e * zEws] = lambda(e, x[e * xEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = __getIndexOffset(e, xShapeInfo);
auto zOffset = __getIndexOffset(e, zShapeInfo);
auto xOffset = getIndexOffset(e, xShapeInfo);
auto zOffset = getIndexOffset(e, zShapeInfo);
z[zOffset] = lambda(e, x[xOffset]);
}
@@ -155,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto zLength = length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -164,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo,
z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = __getIndexOffset(e, xShapeInfo);
auto yOffset = __getIndexOffset(e, yShapeInfo);
auto zOffset = __getIndexOffset(e, zShapeInfo);
auto xOffset = getIndexOffset(e, xShapeInfo);
auto yOffset = getIndexOffset(e, yShapeInfo);
auto zOffset = getIndexOffset(e, zShapeInfo);
z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
}
@@ -188,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto zLength = length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -197,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v
z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto xOffset = __getIndexOffset(e, xShapeInfo);
auto yOffset = __getIndexOffset(e, yShapeInfo);
auto zOffset = __getIndexOffset(e, zShapeInfo);
auto xOffset = getIndexOffset(e, xShapeInfo);
auto yOffset = getIndexOffset(e, yShapeInfo);
auto zOffset = getIndexOffset(e, zShapeInfo);
z[zOffset] = lambda(x[xOffset], y[yOffset]);
}
@@ -224,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto zLength = __length(zShapeInfo);
auto zLength = length(zShapeInfo);
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -233,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void*
z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
} else {
for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) {
auto wOffset = __getIndexOffset(e, wShapeInfo);
auto xOffset = __getIndexOffset(e, xShapeInfo);
auto yOffset = __getIndexOffset(e, yShapeInfo);
auto zOffset = __getIndexOffset(e, zShapeInfo);
auto wOffset = getIndexOffset(e, wShapeInfo);
auto xOffset = getIndexOffset(e, xShapeInfo);
auto yOffset = getIndexOffset(e, yShapeInfo);
auto zOffset = getIndexOffset(e, zShapeInfo);
z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
}

==== changed file ====

@@ -52,6 +52,7 @@ namespace sd {
ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor);
ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo);
ConstantDataBuffer bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape);
ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector<int> dimensions = {});
Nd4jLong* emptyShapeInfo(const sd::DataType dataType);

==== changed file ====

@@ -146,7 +146,62 @@ namespace sd {
return result;
}
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) {
Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);
newShapeInfo[0] = shape::rank(maxShapeInfo);
sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type
newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews
newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order
if(!dimensions.empty()) {
for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) {
if(j < dimensions.size() && dimensions[j] == i) {
shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k];
shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++];
++j;
}
else{
shape::shapeOf(newShapeInfo)[i] = 1;
shape::stride(newShapeInfo)[i] = 0;
if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo))
++k;
}
}
}
else{
for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) {
if(j >= 0) {
shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j];
shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j];
--j;
}
else {
shape::shapeOf(newShapeInfo)[i] = 1;
shape::stride(newShapeInfo)[i] = 0;
}
}
}
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return bufferForShapeInfo(descriptor);
}
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
}
#endif
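To make the two branches of createShapeInfoWithUnitiesForBroadcast concrete, a hedged trace of what the helper yields (values read off the loops above, not produced by running the library):

// dims empty (true broadcast): max shape {2,3,4}; min shape {3,1}, strides {1,1}
//   -> shape {1,3,1}, strides {0,1,0}   (right-aligned, unities inserted)
// dims = {0,2}: min shape {2,4}, strides {4,1}
//   -> shape {2,1,4}, strides {4,0,1}   (min axes pinned to dims 0 and 2)
// data type, ews and order are copied from minShapeInfo, and the resulting
// descriptor is cached via bufferForShapeInfo.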

==== changed file ====

@@ -135,5 +135,59 @@ namespace sd {
return result;
}
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
////////////////////////////////////////////////////////////////////////
ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector<int> dimensions) {
Nd4jLong* newShapeInfo = nullptr;
ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong);
newShapeInfo[0] = shape::rank(maxShapeInfo);
sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type
newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews
newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order
if(!dimensions.empty()) {
for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) {
if(j < dimensions.size() && dimensions[j] == i) {
shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k];
shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++];
++j;
}
else{
shape::shapeOf(newShapeInfo)[i] = 1;
shape::stride(newShapeInfo)[i] = 0;
if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo))
++k;
}
}
}
else{
for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) {
if(j >= 0) {
shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j];
shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 0 : shape::stride(minShapeInfo)[j];
--j;
}
else {
shape::shapeOf(newShapeInfo)[i] = 1;
shape::stride(newShapeInfo)[i] = 0;
}
}
}
ShapeDescriptor descriptor(newShapeInfo);
RELEASE(newShapeInfo, workspace);
return bufferForShapeInfo(descriptor);
}
sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0;
}

==== changed file ====

@@ -118,6 +118,7 @@ namespace shape {
ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3);
ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim);
ND4J_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim);
template <typename T>
ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length);
@@ -2989,6 +2990,15 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons
return shapeInfo[1+(rank(shapeInfo) + dim)];
}
INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) {
if (0 == rank(shapeInfo))
return 1;
if (dim >= 0)
return shapeInfo[1 + rank(shapeInfo) + dim];
else
return shapeInfo[1 + 2*rank(shapeInfo) + dim];
}
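A quick worked read of the new strideAt, using the shapeInfo layout (rank, then shape, then strides, then the type/ews/order trailer); the buffer below is an illustrative c-order {2,3,4} array:

// shapeInfo = [3,  2,3,4,  12,4,1,  <type>, 1, 99]
// strideAt(shapeInfo,  1) reads shapeInfo[1 + 3 + 1] -> 4   (axis 1, counted from the front)
// strideAt(shapeInfo, -1) reads shapeInfo[1 + 6 - 1] -> 1   (last axis, counted from the end)
// a rank-0 (scalar) shapeInfo short-circuits to 1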
/**
* This method does SOFT comparison for two shape buffers, we compare only rank & shapes
*
@@ -4117,7 +4127,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap
*shape::ews(newShapeInfo) = oldEws; // ews
}
newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type
sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type
return true;
}
@@ -4742,7 +4752,7 @@ INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShap
const int subArrRank = keepUnitiesInShape ? rank : rank - dimsSize;
subArrShapeInfo[0] = subArrRank; // rank
subArrShapeInfo[2 * subArrRank + 1] = shape::type(wholeShapeInfo); // type
sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type
subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order
Nd4jLong* shape = new Nd4jLong[dimsSize];
@@ -4820,7 +4830,7 @@ INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong*
}
minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order
minShapeInfo[2 * shape::rank(minShapeInfo) + 1] = shape::type(maxShapeInfo); // type
sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type
shape::checkStridesEwsAndOrder(minShapeInfo);
}
@@ -5114,9 +5124,9 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo,
shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i];
}
outShapeInfo[2 * outShapeInfo[0] + 1] = shape::type(inShapeInfo); // type
*shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type
*shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews
outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order
}

==== changed file ====

@@ -46,7 +46,7 @@ public:
* @param resultShapeInfo
*/
static void execIndexReduceScalar(sd::LaunchContext *lc,
int opNum,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
@@ -75,7 +75,7 @@
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
/**
*
@@ -263,6 +263,15 @@ static void execScalarInt(sd::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
static void execBroadcast(sd::LaunchContext* lc,
const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
@@ -289,6 +298,15 @@ static void execScalarInt(sd::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
static void execBroadcastBool(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams);
static void execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
@@ -314,6 +332,14 @@ static void execScalarInt(sd::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo);
static void execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
@@ -325,7 +351,7 @@ static void execScalarInt(sd::LaunchContext *lc,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
/**
*
* @param opNum
@@ -421,7 +447,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets);
/**
*
* @param opNum
@@ -509,7 +535,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo);
void *dZ, Nd4jLong *dZShapeInfo);
static void execReduce3TAD(sd::LaunchContext *lc,
int opNum,
@@ -520,7 +546,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets);
@@ -544,7 +570,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
bool biasCorrected);
bool biasCorrected);
/**
*
@@ -580,7 +606,7 @@ static void execTransformBool(sd::LaunchContext *lc,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
bool biasCorrected);
bool biasCorrected);
static void execRandom(sd::LaunchContext *lc,
@@ -627,7 +653,7 @@ static void execTransformBool(sd::LaunchContext *lc,
int numRealArguments) {
}
inline static void execSort(void *x, Nd4jLong *xShapeInfo, bool descending) {
auto xType = sd::ArrayOptions::dataType(xShapeInfo);

==== changed file ====

@@ -204,6 +204,30 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
#endif
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES);
#else
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES);
#endif
}
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
@@ -258,13 +282,12 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
};
@@ -276,6 +299,26 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
samediff::Threads::parallel_tad(func, 0, numTads);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) {
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
}
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
@@ -351,6 +394,33 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
samediff::Threads::parallel_tad(func, 0, numTads);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType)
throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);
if (!sd::DataTypeUtils::isZ(zType))
throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), INTEGER_TYPES);
}
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,

==== changed file ====

@@ -265,6 +265,39 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc,
throw cuda_exception::build("execBroadcastBool failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo,
void *extraParams) {
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock
launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = 1024; // shared memory
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
if (res != 0)
throw cuda_exception::build("execBroadcastBool failed", res);
}
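The grid sizing used by these new overloads is one thread per output element, rounded up to whole blocks; a hedged arithmetic check (the MAX_NUM_THREADS value is an assumption, purely for illustration):

// threadsPerBlock = MAX_NUM_THREADS / 4              e.g. 512 / 4 = 128
// blocksPerGrid   = ceil(zLength / threadsPerBlock)  1000 elements -> (1000 + 127) / 128 = 8
// launchDims.z    = 1024                             shared memory bytes per block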
void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
@@ -302,7 +335,6 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc,
throw cuda_exception::build("execInverseBroadcastBool failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
int opNum,
@@ -341,6 +373,44 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc,
throw cuda_exception::build("execBroadcastBool failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
if (yType != xType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock
launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = 1024; // shared memory
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), INTEGER_TYPES)
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
if (res != 0)
throw cuda_exception::build("execBroadcastBool failed", res);
}
void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
@@ -428,6 +498,42 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
throw cuda_exception::build("execBroadcast failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum,
const void *hX, const Nd4jLong *hXShapeInfo,
const void *dX, const Nd4jLong *dXShapeInfo,
const void *hY, const Nd4jLong *hYShapeInfo,
const void *dY, const Nd4jLong *dYShapeInfo,
void *hZ, const Nd4jLong *hZShapeInfo,
void *dZ, const Nd4jLong *dZShapeInfo) {
auto stream = lc->getCudaStream();
auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock
launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = 1024; // shared memory
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES);
#else
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES);
#endif
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
if (res != 0)
throw cuda_exception::build("execBroadcast failed", res);
}
void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,

==== changed file ====

@@ -1,84 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_TRUEBROADCASTHELPER_H
#define LIBND4J_TRUEBROADCASTHELPER_H
#include <array/NDArray.h>
namespace sd {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
class TrueBroadcastHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
template <typename X, typename Y>
class TrueBroadcastBoolHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
////////////////////////////////////////////////////////////////////////
template <typename X>
class TrueBroadcastIntHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
}
}
#endif //LIBND4J_BIDIAGONALUP_H

==== changed file ====

@@ -69,11 +69,21 @@ namespace functions {
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static __device__ void transformInverseCuda(
@@ -170,6 +180,17 @@
Nd4jLong *tadOffsetZ,
uint64_t start,
uint64_t stop);
static void exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
#endif
};
}

==== changed file ====

@@ -69,11 +69,30 @@ namespace functions {
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
template<typename OpType>
static __device__ void transformInverseCuda(
void *x,
@ -110,6 +129,12 @@ namespace functions {
uint64_t start,
uint64_t stop);
static void exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
static void execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
@ -155,6 +180,12 @@ namespace functions {
uint64_t start,
uint64_t stop);
template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams);
template<typename OpType>
static void execInverse(void *x,
Nd4jLong *xShapeInfo,

View File

@ -68,11 +68,27 @@ namespace functions {
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static __device__ void transformInverseCuda(
void *x,
@ -107,6 +123,11 @@ namespace functions {
uint64_t start,
uint64_t stop);
static void exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
static void execInverse(int opNum,
void *x,
Nd4jLong *xShapeInfo,
@ -150,6 +171,11 @@ namespace functions {
uint64_t start,
uint64_t stop);
template<typename OpType>
static void exec(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo);
template<typename OpType>
static void execInverse(void *x,
Nd4jLong *xShapeInfo,

View File

@ -1,291 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <loops/TrueBroadcastHelper.h>
#include <ops/ops.h>
#include <execution/Threads.h>
using namespace simdOps;
namespace sd {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering());
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
auto yLen = yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
for (auto i = start; i < stop; i++) {
auto rZ = z + (i * yLen);
auto v = x[i];
for (Nd4jLong j = 0; j < yLen; j++) {
rZ[j] = OpType::op(v, y[j]);
}
}
};
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
}
auto yShapeInt = yArr.getShapeAsVectorInt();
auto xShapeInt = xArr.getShapeAsVectorInt();
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
if (bSpecialCase && bSpecialCase2) {
uint32_t zDim1 = zArr.sizeAt(-2);
uint32_t zDim2 = zArr.sizeAt(-1);
uint32_t nLen = zArr.lengthOf() / yArr.sizeAt(-1);
auto func = PRAGMA_THREADS_FOR{
for (auto total = start; total < stop; total++) {
uint32_t i = total / zDim1;
uint32_t j = total % zDim1;
uint32_t index = (i * zDim1) + j;
auto rZ = z + (index * zDim2);
auto rY = y + (i * zDim2);
auto rX = x[index];
for (uint32_t n = 0; n < zDim2; n++) {
rZ[n] = OpType::op(rX, rY[n]);
}
}
};
samediff::Threads::parallel_tad(func, 0, nLen, 1);
return;
}
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
X* z = reinterpret_cast<X*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
}
else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
}
else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>
void TrueBroadcastIntHelper<X>::exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
}
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
*/
}
}
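The generic fallback of the deleted helper above pays for a full index-to-coordinates conversion on every output element before collapsing broadcast dimensions; that per-element cost is what the stride-specialized kernels introduced later in this commit avoid. A minimal sketch of the conversion for a c-ordered shape — illustrative, not the library's shape::index2coords:

#include <cstdint>

// Flat index -> per-dimension coordinates for a c-ordered shape:
// the last dimension varies fastest, so peel it off first.
static void index2coordsC(int64_t index, const int64_t* shape, int rank, int64_t* coords) {
    for (int i = rank - 1; i >= 0; --i) {
        coords[i] = index % shape[i];
        index    /= shape[i];
    }
}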

View File

@ -30,7 +30,7 @@
using namespace simdOps;
namespace functions {
namespace broadcast {
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::execInverse(const int opNum,
@ -93,6 +93,7 @@ namespace functions {
zTadOffset, loopKind, start, stop), BROADCAST_OPS);
}
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::exec(void *vx,
@ -562,5 +563,205 @@ namespace functions {
};
}
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
void Broadcast<X, Y, Z>::exec(const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx);
const Y* y = reinterpret_cast<const Y*>(vy);
Z* z = reinterpret_cast<Z*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1)
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0]);
else
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? i1 : 0]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ? i2 : 0]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? i3 : 0]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
auto func = PRAGMA_THREADS_FOR{
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
}
}
}
}
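The exec above relies on the precondition flagged in its first comment (xRank = yRank = zRank): the shorter operand's shape is padded with unit dimensions before the call. A minimal sketch of such padding, assuming left-padding and that inserted unit axes are later given stride 0 so the `xStrd ? i : 0` pattern above reuses element 0 along them:

#include <vector>
#include <cstdint>

// Left-pad a shape with 1s up to targetRank: {3, 4} broadcast against a
// rank-3 target becomes {1, 3, 4}. Hypothetical helper for illustration.
static std::vector<int64_t> padWithUnities(const std::vector<int64_t>& shape, int targetRank) {
    std::vector<int64_t> out(static_cast<size_t>(targetRank), 1);
    const int offset = targetRank - static_cast<int>(shape.size());
    for (size_t i = 0; i < shape.size(); ++i)
        out[static_cast<size_t>(offset) + i] = shape[i];
    return out;
}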

View File

@ -63,6 +63,16 @@ namespace functions {
zTadOffset, start, stop), BROADCAST_BOOL_OPS);
}
template <typename X, typename Y>
void BroadcastBool<X, Y>::exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void* extraParams) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), BROADCAST_BOOL_OPS);
}
template <typename X, typename Y>
void BroadcastBool<X, Y>::execInverse(const int opNum,
void *x,
@ -267,8 +277,203 @@ namespace functions {
}
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void BroadcastBool<X, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams) {
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);
Z* z = reinterpret_cast<Z*>(vz);
X* extraParams = reinterpret_cast<X*>(vextraParams);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1)
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0], extraParams);
else
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams);
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? i1 : 0], extraParams);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ? i2 : 0], extraParams);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? i3 : 0], extraParams);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0], extraParams);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
auto func = PRAGMA_THREADS_FOR{
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
}
}
template <typename X, typename Z>
template<typename OpType>
void BroadcastBool<X, Z>::execInverse(void *vx,
Nd4jLong *xShapeInfo,

View File

@ -61,6 +61,15 @@ namespace functions {
zTadOffset, start, stop), BROADCAST_INT_OPS);
}
template <typename X>
void BroadcastInt<X>::exec(const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_INT_OPS);
}
template <typename X>
void BroadcastInt<X>::execInverse(const int opNum,
void *x,
@ -430,6 +439,200 @@ namespace functions {
}
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void BroadcastInt<X>::exec(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);
X* z = reinterpret_cast<X*>(vz);
const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank
const char zOrder = shape::order(zShapeInfo);
uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
switch (rank) {
case 1: {
auto func = PRAGMA_THREADS_FOR{
if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1)
for (auto i0 = start; i0 < stop; ++i0)
z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0]);
else
for (auto i0 = start; i0 < stop; ++i0)
z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 2: {
auto func = PRAGMA_THREADS_FOR{
for (auto i0 = start; i0 < stop; ++i0) {
auto x0 = x + i0 * xStrd0;
auto y0 = y + i0 * yStrd0;
auto z0 = z + i0 * zStrd0;
if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1)
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? i1 : 0]);
else
for (uint i1 = 0; i1 < zAxis1; ++i1)
z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
}
};
samediff::Threads::parallel_tad(func, 0, zAxis0);
}
break;
case 3: {
auto func = PRAGMA_THREADS_FOR_2D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1)
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ? i2 : 0]);
else
for (uint i2 = 0; i2 < zAxis2; ++i2)
z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]);
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1);
}
break;
case 4: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2;
auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2;
auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2;
if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1)
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? i3 : 0]);
else
for (uint i3 = 0; i3 < zAxis3; ++i3)
z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]);
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
case 5: {
auto func = PRAGMA_THREADS_FOR_3D {
for (auto i0 = start_x; i0 < stop_x; ++i0) {
for (auto i1 = start_y; i1 < stop_y; ++i1) {
for (auto i2 = start_z; i2 < stop_z; ++i2) {
for (uint i3 = 0; i3 < zAxis3; ++i3) {
auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3;
auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3;
auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3;
if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1)
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0]);
else
for (uint i4 = 0; i4 < zAxis4; ++i4)
z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]);
}
}
}
}
};
samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1);
}
break;
default: {
auto func = PRAGMA_THREADS_FOR{
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo));
}
}
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}
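All three specializations (Broadcast, BroadcastBool, BroadcastInt) share the same inner-loop idiom: the `xStrd <= 1` fast path folds the dense case (stride 1) and the broadcast case (stride 0, where one element is reused along the whole axis) into a single loop, with `x0[xStrd1 ? i1 : 0]` selecting between them. A standalone sketch of the idiom:

#include <cstdint>

// One broadcastable axis: stride 0 pins the read to element 0 (the same
// value is reused across the axis); stride 1 walks the buffer densely.
static void addAxis(const float* x, int64_t xStride,
                    const float* y, int64_t yStride,
                    float* z, int64_t n) {
    for (int64_t i = 0; i < n; ++i)
        z[i] = x[xStride ? i : 0] + y[yStride ? i : 0];
}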

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
}
}

View File

@ -1,27 +0,0 @@
/*******************************************************************************
* Copyright (c) 2019-2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../TrueBroadcastHelper.hpp"
namespace sd {
namespace helpers {
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
}
}
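The deleted one-line translation units above (one per PAIRWISE_TYPES_N bucket, plus the bool and int variants) existed solely to spread explicit template instantiations across object files, bounding per-file compile time and memory. A self-contained illustration of the pattern with hypothetical names:

// One translation unit instantiates one type bucket; sibling files
// instantiate the rest, so the compiler's work is split across objects.
template <typename T>
struct Helper {
    static T twice(T v) { return v + v; }   // stand-in for the real methods
};

template struct Helper<float>;   // this .cpp handles only the float bucket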

View File

@ -1,312 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
// #include <exceptions/cuda_exception.h>
#include <loops/TrueBroadcastHelper.h>
#include <helpers/PointersManager.h>
#include <execution/LaunchContext.h>
#include <ops/specials.h>
#include <helpers/logger.h>
#include <ops/ops.h>
// #include <cuda_runtime.h>
// #include <cuda.h>
using namespace simdOps;
namespace sd {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
__global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Y*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template <typename OpType>
void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void TrueBroadcastHelper<X,Y,Z>::exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_TTT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
__global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpType>
void TrueBroadcastBoolHelper<X,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastBoolCuda<X,Z,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
void TrueBroadcastBoolHelper<X,Y>::exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_TT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_BOOL_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
__global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<X*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpType>
void TrueBroadcastIntHelper<X>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastIntCuda<X,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
void TrueBroadcastIntHelper<X>::exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_T(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_INT_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
}
}


@@ -46,6 +46,15 @@ static __global__ void broadcastSimple(
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo ) {
functions::broadcast::Broadcast<X,Y,Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
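// This shape-only overload backs the new broadcast path: callers pre-expand the
// thinner operand's shapeInfo with size-1 axes (see
// ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast), so no
// dimension/TAD arguments are needed and transformCuda can assume equal ranks.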
template<typename X, typename Y, typename Z, typename OpClass>
static __global__ void broadcastInverseSimple(
void *x,
@@ -64,11 +73,11 @@ static __global__ void broadcastInverseSimple(
namespace functions {
namespace broadcast {
static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}
static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
@@ -78,6 +87,12 @@ namespace functions {
broadcastSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
template<typename X, typename Y, typename Z>
template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
broadcastSimple<X, Y, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
@@ -85,6 +100,13 @@ namespace functions {
DEBUG_KERNEL(stream, opNum);
}
template<typename X, typename Y, typename Z>
__host__ void Broadcast<X,Y,Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS))
DEBUG_KERNEL(stream, opNum);
}
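// Host-side call sketch (editor illustration; buffers, shapeInfos and
// launchDims assumed already prepared):
//   functions::broadcast::Broadcast<float, float, float>::execBroadcast(
//       launchDims, stream, opNum,
//       xBuf, xShapeInfo, yBuf, yShapeInfo, zBuf, zShapeInfo);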
template<typename X, typename Y, typename Z>
template <typename OpClass>
__host__ void Broadcast<X,Y,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@@ -128,9 +150,9 @@ namespace functions {
if (threadIdx.x == 0) {
tadLength = _length(tadOnlyShapeInfo);
tadLength = length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = _length(yShapeInfo) / tadLength;
numTads = length(yShapeInfo) / tadLength;
xEWS = shape::elementWiseStride(xShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}
@@ -154,9 +176,9 @@ namespace functions {
else {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
auto xOffset = _getIndexOffset(i, xShapeInfo);
auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo);
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
auto xOffset = getIndexOffset(i, xShapeInfo);
auto yOffset = getIndexOffset(i, tadOnlyShapeInfo);
auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
}
}
@@ -193,9 +215,9 @@ namespace functions {
__shared__ Nd4jLong zEWS;
if (threadIdx.x == 0) {
tadLength = _length(tadOnlyShapeInfo);
tadLength = length(tadOnlyShapeInfo);
tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
numTads = _length(xShapeInfo) / tadLength;
numTads = length(xShapeInfo) / tadLength;
yEWS = shape::elementWiseStride(yShapeInfo);
zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
}
@@ -219,15 +241,59 @@ namespace functions {
// it is expected that x and z tads and y array all have the same length
for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo);
auto yOffset = _getIndexOffset(i, yShapeInfo);
auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
auto xOffset = getIndexOffset(i, tadOnlyShapeInfo);
auto yOffset = getIndexOffset(i, yShapeInfo);
auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
}
}
}
}
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template <typename OpType>
__device__ void Broadcast<X,Y,Z>::transformCuda(
const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx);
const Y* y = reinterpret_cast<const Y*>(vy);
Z* z = reinterpret_cast<Z*>(vz);
__shared__ Nd4jLong zLen;
__shared__ int rank;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (Nd4jLong i = tid; i < zLen; i += blockDim.x * gridDim.x) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
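// Worked example (editor note): zShape {2,3}, xShape {1,3}, yShape {2,1}. For
// i = 4, zCoords = {1,1}; x dim 0 has size 1 so xCoords = {0,1}, y dim 1 has
// size 1 so yCoords = {1,0}; each thread reads the single broadcasted
// row/column element directly, with no TAD bookkeeping.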
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);


@@ -47,6 +47,15 @@ static __global__ void broadcastBoolSimple(
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z, typename OpClass>
static __global__ void broadcastBoolInverseSimple(
@@ -64,22 +73,48 @@ static __global__ void broadcastBoolInverseSimple(
}
namespace functions {
namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
DEBUG_KERNEL(stream, opNum);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo,
void *extraParams) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
@@ -229,7 +264,53 @@ namespace functions {
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpType>
__device__ void BroadcastBool<X,Z>::transformCuda(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo,
void *vextraParams) {
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);
Z* z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
__shared__ Nd4jLong zLen;
__shared__ int rank;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (Nd4jLong i = tid; i < zLen; i += blockDim.x * gridDim.x) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
}
}
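// Editor note: the bool variant reads two X-typed inputs and writes a Z-typed
// (boolean) result, e.g. an EqualTo op yields z[zOffset] = (x == y) per
// element; extraParams is simply forwarded for the ops that consume it.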
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES);
}
}


@@ -46,6 +46,15 @@ static __global__ void broadcastIntSimple(
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastBoolInverseSimple(
@@ -62,19 +71,40 @@ static __global__ void broadcastBoolInverseSimple(
}
namespace functions {
namespace broadcast {
namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
const void *x, const Nd4jLong *xShapeInfo,
const void *y, const Nd4jLong *yShapeInfo,
void *z, const Nd4jLong *zShapeInfo) {
DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS))
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
@@ -217,6 +247,50 @@ namespace functions {
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong *xShapeInfo,
const void *vy, const Nd4jLong *yShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
const X* x = reinterpret_cast<const X*>(vx);
const X* y = reinterpret_cast<const X*>(vy);
X* z = reinterpret_cast<X*>(vz);
__shared__ Nd4jLong zLen;
__shared__ int rank;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
for (Nd4jLong i = tid; i < zLen; i += blockDim.x * gridDim.x) {
shape::index2coords(i, zShapeInfo, zCoords);
for (uint j = 0; j < rank; ++j) {
xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
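// Editor note: unlike Broadcast<X,Y,Z>, the integer variant is mono-typed:
// a single template parameter X serves both inputs and the output, which fits
// the bitwise ops (e.g. IntAnd, IntOr, IntXor, shifts) this helper dispatches.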
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}


@@ -21,17 +21,14 @@
#include <loops/special_kernels.h>
namespace sd {
static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
static Nd4jLong __device__ __noinline__ getIndexOffset_(Nd4jLong index, Nd4jLong *shapeInfo) {
return shape::getIndexOffset(index, shapeInfo);
}
static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
static Nd4jLong __device__ __noinline__ subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
return shape::subArrayOffset(index, shapeInfoA, shapeInfoB);
}
static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
return shape::length(shapeInfo);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// tileKernel:
@@ -48,13 +45,13 @@ namespace sd {
int totalThreads = gridDim.x * blockDim.x;
if (shape::order(outputShape) == 'c') { // ews == 1 always here
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
auto yOffset = subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<T *>(outputBuffer) + i) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
}
} else {
for (int i = tid; i < resultLength; i += totalThreads) {
auto xOffset = _getIndexOffset(i, outputShape);
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
auto xOffset = getIndexOffset_(i, outputShape);
auto yOffset = subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<T *>(outputBuffer) + xOffset) = *(reinterpret_cast<T const *>(inputBuffer) + yOffset);
}
}
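// Editor note: tiling maps each output index to its source element through
// shape::subArrayOffset, e.g. tiling an input of shape {2,3} to an output of
// shape {4,3} sends output row r to input row r % 2.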
@@ -83,20 +80,20 @@ namespace sd {
if (ordering == 'c' && ews == 1) {
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
auto yOffset = subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
} else if (ordering == 'c' && ews > 1) {
for (int i = tid; i < resultLength; i += totalThreads) {
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
auto yOffset = subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + i * ews) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
} else {
for (int i = tid; i < resultLength; i += totalThreads) {
auto xOffset = _getIndexOffset(i, outputShape);
auto yOffset = _subArrayOffset(i, outputShape, inputShape);
auto xOffset = getIndexOffset_(i, outputShape);
auto yOffset = subArrayOffset(i, outputShape, inputShape);
*(reinterpret_cast<X *>(outputBuffer) + xOffset) = static_cast<X>(*(reinterpret_cast<Y const *>(inputBuffer) + yOffset));
}
}


@@ -170,7 +170,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
graph->getVariableSpace()->putVariable(86,0, u);
graph->getVariableSpace()->putVariable(87,0, v);
/*
// validating graph now
auto status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
@@ -179,7 +179,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
ASSERT_EQ(z, *array);
/*
*/
sd::Environment::getInstance()->setProfiling(true);
auto profile = GraphProfilingHelper::profile(graph, 1);
@@ -187,7 +187,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
sd::Environment::getInstance()->setProfiling(false);
delete profile;
*/
/*
std::vector<Nd4jLong> values;