diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 38ea087ba..fa39a00f6 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -26,7 +26,6 @@ #include #include #include -#include namespace sd { @@ -2801,15 +2800,15 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, if (isEmpty() || other.isEmpty()) return; - if (lengthOf() == 1) { - target.assign(this); - target.applyPairwiseTransform(op.p, other, extraArgs); - return; - } - if (other.lengthOf() == 1) { - const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); - return; - } + // if (lengthOf() == 1) { + // target.assign(this); + // target.applyPairwiseTransform(op.p, other, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; @@ -2819,36 +2818,46 @@ void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); } - if(target.isSameShape(this) || target.isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); - return; + Nd4jLong* xShapeInfoH = getShapeInfo(); + Nd4jLong* yShapeInfoH = other.getShapeInfo(); + Nd4jLong* xShapeInfoD = getSpecialShapeInfo(); + Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo(); + + if(!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if(!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); } - #ifdef __ND4J_EXPERIMENTAL__ - BUILD_PAIRWISE_SELECTOR(dataType(), other.dataType(), target.dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, LIBND4J_TYPES); - #else - BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES); - #endif + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcast(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); if (isEmpty() || other.isEmpty()) return; - if (lengthOf() == 1) { - NDArray temp(target._shapeInfo, dataType(), false, getContext()); - temp.assign(this); - temp.applyPairwiseTransform(op.p, other, target, extraArgs); - return; - } - if (other.lengthOf() == 1) { - this->applyScalarArr(op.s, 
other, target, extraArgs); - return; - } + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; @@ -2860,12 +2869,25 @@ void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& ot throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); } - if(target.isSameShape(this) || target.isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); - return; + Nd4jLong* xShapeInfoH = getShapeInfo(); + Nd4jLong* yShapeInfoH = other.getShapeInfo(); + Nd4jLong* xShapeInfoD = getSpecialShapeInfo(); + Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo(); + + if(!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if(!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); } - BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, other, target), LIBND4J_TYPES, BOOL_TYPES); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastBool(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// @@ -2877,16 +2899,16 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& oth if (isEmpty() || other.isEmpty()) return; - if (lengthOf() == 1) { - NDArray temp(target._shapeInfo, dataType(), false, getContext()); - temp.assign(this); - temp.applyPairwiseTransform(op.p, other, target, extraArgs); - return; - } - if (other.lengthOf() == 1) { - this->applyScalarArr(op.s, other, target, extraArgs); - return; - } + // if (lengthOf() == 1) { + // NDArray temp(target._shapeInfo, dataType(), false, getContext()); + // temp.assign(this); + // temp.applyPairwiseTransform(op.p, other, target, extraArgs); + // return; + // } + // if (other.lengthOf() == 1) { + // this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } if(checkTargetShape) { Nd4jLong* newShapeInfo = nullptr; @@ -2898,12 +2920,25 @@ void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& oth throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); } - if(target.isSameShape(this) || target.isSameShape(other)) { - const_cast(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, other), other, target, extraArgs); - return; + Nd4jLong* xShapeInfoH = getShapeInfo(); + Nd4jLong* 
yShapeInfoH = other.getShapeInfo(); + Nd4jLong* xShapeInfoD = getSpecialShapeInfo(); + Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo(); + + if(!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if(!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); } - BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, other, target), INTEGER_TYPES); + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt(getContext(), op.b, getBuffer(), xShapeInfoH, getSpecialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// @@ -3008,6 +3043,10 @@ NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, E ////////////////////////////////////////////////////////////////////////// void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { + + if (dimensions.size() == 0) + return; + if (isS()) throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!"); if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB())) @@ -3018,53 +3057,50 @@ void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector& dime return; } - if (dimensions.size() == 0) - return; - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); - return; - } - - NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { - max = this; - min = const_cast(&other); - } - else { - max = const_cast(&other); - min = this; - } + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + // NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } if(target.dataType() != 
DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.getShapeInfo()))
         throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !");
-    if(!target.isSameShape(max))
-        throw std::invalid_argument("NDArray::applyBroadcast method: max and target arrays must have the same shape !");
+    if(!target.isSameShape(this) && !target.isSameShape(other))
+        throw std::invalid_argument("NDArray::applyBroadcast method: one of the two input arrays (this or other) should have the same shape as the target array!");

     std::vector<int> copy(dimensions);

     if (dimensions.size() > 1)
         std::sort(copy.begin(), copy.end());

-    Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
-    if (tadLength != min->lengthOf())
-        throw std::runtime_error("NDArray::applyBroadcast method: tad length mismatch !");
+    Nd4jLong* xShapeInfoH = getShapeInfo();
+    Nd4jLong* yShapeInfoH = other.getShapeInfo();
+    Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
+    Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();

-    auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
-    auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
+    if(!isSameShape(target)) {
+        auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
+        xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
+        xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
+    }
+    if(!other.isSameShape(target)) {
+        auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
+        yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
+        yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
+    }

     NDArray::prepareSpecialUse({&target}, {this, &other});
-    if(max == this)
-        NativeOpExecutioner::execBroadcast( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
-    else
-        NativeOpExecutioner::execInverseBroadcast(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
+    NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());

     registerSpecialUse({&target}, {this, &other});
 }

 //////////////////////////////////////////////////////////////////////////
 void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
+
+    if (dimensions.size() == 0)
+        return;
+
     if (isS())
         throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!");

     if(isEmpty() || other.isEmpty()) {
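////////////////////////////////////////////////////////////////////////
// Editor's note: a minimal sketch (not part of the patch) of what
// createShapeInfoWithUnitiesForBroadcast produces in the no-dimensions
// case used by the applyTrueBroadcast overloads above (applyBroadcast
// passes its sorted dimension list instead). The smaller operand's shape
// is right-aligned against the target's shape and padded with unities,
// and every unit axis gets stride 0, so the same element is re-read
// along that axis:
//
//     target shape {2, 3, 4}, input shape {3, 1}  ->  view {1, 3, 1}
//
// The helper below is hypothetical and only mirrors the else branch of
// the new routine; the real code writes straight into a shapeInfo buffer.

#include <cstdint>
#include <vector>

static void expandWithUnities(const std::vector<int64_t>& maxShape,
                              const std::vector<int64_t>& minShape,
                              const std::vector<int64_t>& minStride,
                              std::vector<int64_t>& outShape,
                              std::vector<int64_t>& outStride) {
    const int maxRank = static_cast<int>(maxShape.size());
    outShape.assign(maxRank, 1);    // unmatched leading axes become 1 ...
    outStride.assign(maxRank, 0);   // ... with stride 0

    // walk both shapes from their last axes toward the front
    for (int j = static_cast<int>(minShape.size()) - 1, i = maxRank - 1; i >= 0 && j >= 0; --i, --j) {
        outShape[i]  = minShape[j];
        outStride[i] = minShape[j] == 1 ? 0 : minStride[j];   // unit axis => stride 0
    }
}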
@@ -3073,30 +3109,17 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
         return;
     }

-    if (dimensions.size() == 0)
-        return;
-
-    if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-        NDArray::prepareSpecialUse({&target}, {this, &other});
-        NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
-        NDArray::registerSpecialUse({&target}, {this, &other});
-        return;
-    }
-
-    NDArray *min(nullptr), *max(nullptr);
-    if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) {
-        max = this;
-        min = const_cast<NDArray*>(&other);
-    }
-    else {
-        max = const_cast<NDArray*>(&other);
-        min = this;
-    }
+    // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+    //     NDArray::prepareSpecialUse({&target}, {this, &other});
+    //     NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
+    //     NDArray::registerSpecialUse({&target}, {this, &other});
+    //     return;
+    // }

     if(target.dataType() != DataType::BOOL)
         throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!");

-    if(!target.isSameShape(max))
-        throw std::invalid_argument("NDArray::applyBroadcast bool method: max and target arrays must have the same shape !");
+    if(!target.isSameShape(this) && !target.isSameShape(other))
+        throw std::invalid_argument("NDArray::applyBroadcast bool method: one of the two input arrays (this or other) should have the same shape as the target array!");

     if(_dataType != other._dataType)
         throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !");

@@ -3105,25 +3128,34 @@ void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>&
     if (dimensions.size() > 1)
         std::sort(copy.begin(), copy.end());

-    Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
-    if (tadLength != min->lengthOf())
-        throw std::runtime_error("Tad length mismatch");
+    Nd4jLong* xShapeInfoH = getShapeInfo();
+    Nd4jLong* yShapeInfoH = other.getShapeInfo();
+    Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
+    Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();

-    auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
-    auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
+    if(!isSameShape(target)) {
+        auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
+        xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
+        xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
+    }
+    if(!other.isSameShape(target)) {
+        auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
+        yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
+        yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
+    }

-    // TODO: eventually we
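////////////////////////////////////////////////////////////////////////
// Editor's note: a hedged usage sketch (not from the patch) of the
// dimensioned path above -- an ordinary bias-add along axis 1. With an
// explicit dimension list, the smaller operand's axes are scattered to
// the listed positions within the target's rank and every remaining
// axis becomes a unity:
//
//     NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32);
//     NDArray bias('c', {3}, sd::DataType::FLOAT32);
//     NDArray z('c', {2, 3, 4}, sd::DataType::FLOAT32);
//
//     x.applyBroadcast(sd::broadcast::Add, {1}, bias, z);
//     // bias is viewed through shape {1, 3, 1} / strides {0, 1, 0}, so a
//     // single flat execBroadcast call replaces the old TAD decomposition.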
want separate tads here NDArray::prepareSpecialUse({&target}, {this, &other}); - if(max == this) - NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); - else - NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); registerSpecialUse({&target}, {this, &other}); } ////////////////////////////////////////////////////////////////////////// void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) { + + if (dimensions.empty()) + return; + if (!isZ()) throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); if(isEmpty() || other.isEmpty()) { @@ -3132,30 +3164,17 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector& d return; } - if (dimensions.empty()) - return; - - if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&target}, {this, &other}); - return; - } - - NDArray *min(nullptr), *max(nullptr); - if((lengthOf() > other.lengthOf()) || (lengthOf() == other.lengthOf() && rankOf() >= other.rankOf())) { - max = this; - min = const_cast(&other); - } - else { - max = const_cast(&other); - min = this; - } + // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { + // NDArray::prepareSpecialUse({&target}, {this, &other}); + // NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); + // NDArray::registerSpecialUse({&target}, {this, &other}); + // return; + // } if(target.dataType() != dataType()) throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); - if(!target.isSameShape(max)) - throw std::invalid_argument("NDArray::applyBroadcast int method: max and target 
arrays must have the same shape !");
+    if(!target.isSameShape(this) && !target.isSameShape(other))
+        throw std::invalid_argument("NDArray::applyBroadcast int method: one of the two input arrays (this or other) should have the same shape as the target array!");

     if(_dataType != other._dataType)
         throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");

@@ -3164,19 +3183,24 @@ void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& d
     if (dimensions.size() > 1)
         std::sort(copy.begin(), copy.end());

-    Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
-    if (tadLength != min->lengthOf())
-        throw std::runtime_error("Tad length mismatch");
+    Nd4jLong* xShapeInfoH = getShapeInfo();
+    Nd4jLong* yShapeInfoH = other.getShapeInfo();
+    Nd4jLong* xShapeInfoD = getSpecialShapeInfo();
+    Nd4jLong* yShapeInfoD = other.getSpecialShapeInfo();

-    auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
-    auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(target.shapeInfo(), copy);
+    if(!isSameShape(target)) {
+        auto xPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), getShapeInfo(), getContext()->getWorkspace(), copy);
+        xShapeInfoH = reinterpret_cast<Nd4jLong*>(xPack.primary());
+        xShapeInfoD = reinterpret_cast<Nd4jLong*>(xPack.special());
+    }
+    if(!other.isSameShape(target)) {
+        auto yPack = ConstantShapeHelper::getInstance()->createShapeInfoWithUnitiesForBroadcast(target.getShapeInfo(), other.getShapeInfo(), other.getContext()->getWorkspace(), copy);
+        yShapeInfoH = reinterpret_cast<Nd4jLong*>(yPack.primary());
+        yShapeInfoD = reinterpret_cast<Nd4jLong*>(yPack.special());
+    }

-    // TODO: eventually we want separate tads here
     NDArray::prepareSpecialUse({&target}, {this, &other});
-    if(max == this)
-        NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
-    else
-        NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
+    NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.getBuffer(), yShapeInfoH, other.getSpecialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());

     registerSpecialUse({&target}, {this, &other});
 }

diff --git a/libnd4j/include/array/NDArrayLambda.hXX b/libnd4j/include/array/NDArrayLambda.hXX
index 718a35527..50d9bc8d6 100644
--- a/libnd4j/include/array/NDArrayLambda.hXX
+++ b/libnd4j/include/array/NDArrayLambda.hXX
@@ -23,11 +23,11 @@
 #include
 #include

-static Nd4jLong __device__ __noinline__ __getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
     return
shape::getIndexOffset(index, shapeInfo); } -static Nd4jLong __device__ __noinline__ __length(Nd4jLong *shapeInfo) { +static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) { return shape::length(shapeInfo); } @@ -94,7 +94,7 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = __length(zShapeInfo); + auto zLength = length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -103,8 +103,8 @@ static _CUDA_G void lambdaKernel(void* vx, Nd4jLong *xShapeInfo, void *vz, Nd4jL z[e * zEws] = lambda(x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = __getIndexOffset(e, xShapeInfo); - auto zOffset = __getIndexOffset(e, zShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); z[zOffset] = lambda(x[xOffset]); } @@ -123,7 +123,7 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz auto xOrder = shape::order(xShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = __length(zShapeInfo); + auto zLength = length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -132,8 +132,8 @@ static _CUDA_G void lambdaIndexedKernel(void* vx, Nd4jLong *xShapeInfo, void *vz z[e * zEws] = lambda(e, x[e * xEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = __getIndexOffset(e, xShapeInfo); - auto zOffset = __getIndexOffset(e, zShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); z[zOffset] = lambda(e, x[xOffset]); } @@ -155,7 +155,7 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = __length(zShapeInfo); + auto zLength = length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -164,9 +164,9 @@ static _CUDA_G void lambdaIndexedPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = __getIndexOffset(e, xShapeInfo); - auto yOffset = __getIndexOffset(e, yShapeInfo); - auto zOffset = __getIndexOffset(e, zShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); z[zOffset] = lambda(e, x[xOffset], y[yOffset]); } @@ -188,7 +188,7 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = __length(zShapeInfo); + auto zLength = length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -197,9 +197,9 @@ static _CUDA_G void lambdaPairwiseKernel(void* vx, Nd4jLong *xShapeInfo, void* v z[e * zEws] = lambda(x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto xOffset = __getIndexOffset(e, xShapeInfo); - auto yOffset = __getIndexOffset(e, yShapeInfo); - auto zOffset = __getIndexOffset(e, zShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); z[zOffset] = lambda(x[xOffset], y[yOffset]); } @@ -224,7 +224,7 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, 
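////////////////////////////////////////////////////////////////////////
// Editor's note: the renames in this file drop the leading "__" because
// any identifier containing a double underscore is reserved for the
// implementation in C++ ([lex.name]), so __getIndexOffset and __length
// were formally invalid names. The __noinline__ wrappers themselves are
// kept, presumably so each lambda kernel calls one out-of-line copy of
// the shape arithmetic instead of inlining shape::getIndexOffset at
// every offset computation in every template instantiation.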
void* auto yOrder = shape::order(yShapeInfo); auto zOrder = shape::order(zShapeInfo); - auto zLength = __length(zShapeInfo); + auto zLength = length(zShapeInfo); auto tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -233,10 +233,10 @@ static _CUDA_G void lambdaTriplewiseKernel(void* vw, Nd4jLong *wShapeInfo, void* z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]); } else { for (uint e = tid; e < zLength; e += blockDim.x * gridDim.x) { - auto wOffset = __getIndexOffset(e, wShapeInfo); - auto xOffset = __getIndexOffset(e, xShapeInfo); - auto yOffset = __getIndexOffset(e, yShapeInfo); - auto zOffset = __getIndexOffset(e, zShapeInfo); + auto wOffset = getIndexOffset(e, wShapeInfo); + auto xOffset = getIndexOffset(e, xShapeInfo); + auto yOffset = getIndexOffset(e, yShapeInfo); + auto zOffset = getIndexOffset(e, zShapeInfo); z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]); } diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index 56f9d6aeb..4454776a4 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -52,6 +52,7 @@ namespace sd { ConstantDataBuffer bufferForShapeInfo(const ShapeDescriptor &descriptor); ConstantDataBuffer bufferForShapeInfo(const Nd4jLong *shapeInfo); ConstantDataBuffer bufferForShapeInfo(const sd::DataType dataType, const char order, const int rank, const Nd4jLong* shape); + ConstantDataBuffer createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace = nullptr, const std::vector dimensions = {}); Nd4jLong* emptyShapeInfo(const sd::DataType dataType); diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index b7ffa15f5..a69614906 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -146,7 +146,62 @@ namespace sd { return result; } - sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; + +//////////////////////////////////////////////////////////////////////// +ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector dimensions) { + + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if(!dimensions.empty()) { + + for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + + if(j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } + else{ + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } + } + } + else{ + + for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { + + if(j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 
1 ? 0 : shape::stride(minShapeInfo)[j]; + --j; + } + else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } + } + } + + ShapeDescriptor descriptor(newShapeInfo); + + RELEASE(newShapeInfo, workspace); + + return bufferForShapeInfo(descriptor); +} + + +sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; + } #endif \ No newline at end of file diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index e4719bd74..ebce6aac5 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -135,5 +135,59 @@ namespace sd { return result; } - sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; +//////////////////////////////////////////////////////////////////////// +ConstantDataBuffer ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast(const Nd4jLong* maxShapeInfo, const Nd4jLong* minShapeInfo, sd::memory::Workspace* workspace, const std::vector dimensions) { + + Nd4jLong* newShapeInfo = nullptr; + ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(shape::rank(maxShapeInfo)), Nd4jLong); + + newShapeInfo[0] = shape::rank(maxShapeInfo); + + sd::ArrayOptions::copyDataType(newShapeInfo, minShapeInfo); // type + newShapeInfo[2 * newShapeInfo[0] + 2] = shape::elementWiseStride(minShapeInfo); // ews + newShapeInfo[2 * newShapeInfo[0] + 3] = shape::order(minShapeInfo); // order + + if(!dimensions.empty()) { + + for(uint k = 0, j = 0, i = 0; i < shape::rank(maxShapeInfo); ++i) { + + if(j < dimensions.size() && dimensions[j] == i) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[k]; + shape::stride(newShapeInfo)[i] = shape::stride(minShapeInfo)[k++]; + ++j; + } + else{ + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + if(shape::sizeAt(minShapeInfo, k) == 1 && dimensions.size() != shape::rank(minShapeInfo)) + ++k; + } + } + } + else{ + + for(int j = shape::rank(minShapeInfo) - 1, i = shape::rank(maxShapeInfo) - 1; i >=0 ; --i) { + + if(j >= 0) { + shape::shapeOf(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j]; + shape::stride(newShapeInfo)[i] = shape::shapeOf(minShapeInfo)[j] == 1 ? 
0 : shape::stride(minShapeInfo)[j]; + --j; + } + else { + shape::shapeOf(newShapeInfo)[i] = 1; + shape::stride(newShapeInfo)[i] = 0; + } + } + } + + ShapeDescriptor descriptor(newShapeInfo); + + RELEASE(newShapeInfo, workspace); + + return bufferForShapeInfo(descriptor); +} + + +sd::ConstantShapeHelper* sd::ConstantShapeHelper::_INSTANCE = 0; + } \ No newline at end of file diff --git a/libnd4j/include/helpers/shape.h b/libnd4j/include/helpers/shape.h index 431734799..4149996f7 100644 --- a/libnd4j/include/helpers/shape.h +++ b/libnd4j/include/helpers/shape.h @@ -118,6 +118,7 @@ namespace shape { ND4J_EXPORT _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2, const Nd4jLong *shapeInfo3); ND4J_EXPORT _CUDA_HD int sizeAt(const Nd4jLong *shapeInfo, const int dim); + ND4J_EXPORT _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim); template ND4J_EXPORT _CUDA_HD void fill(T* buffer, T value, Nd4jLong length); @@ -2989,6 +2990,15 @@ INLINEDEF _CUDA_HD bool haveSameShapeAndStrides(const Nd4jLong *shapeInfo1, cons return shapeInfo[1+(rank(shapeInfo) + dim)]; } + INLINEDEF _CUDA_HD Nd4jLong strideAt(const Nd4jLong *shapeInfo, const int dim) { + if (0 == rank(shapeInfo)) + return 1; + if (dim >= 0) + return shapeInfo[1 + rank(shapeInfo) + dim]; + else + return shapeInfo[1 + 2*rank(shapeInfo) + dim]; + } + /** * This method does SOFT comparison for two shape buffers, we compare only rank & shapes * @@ -4117,7 +4127,7 @@ INLINEDEF _CUDA_HD bool reshapeC(const Nd4jLong* oldShapeInfo, Nd4jLong* newShap *shape::ews(newShapeInfo) = oldEws; // ews } - newShapeInfo[2 * newRank + 1] = shape::type(oldShapeInfo); // type + sd::ArrayOptions::copyDataType(newShapeInfo, oldShapeInfo); // type return true; } @@ -4742,7 +4752,7 @@ INLINEDEF _CUDA_HD void calcSubArrsShapeInfoAndOffsets(const Nd4jLong* wholeShap const int subArrRank = keepUnitiesInShape ? 
rank : rank - dimsSize; subArrShapeInfo[0] = subArrRank; // rank - subArrShapeInfo[2 * subArrRank + 1] = shape::type(wholeShapeInfo); // type + sd::ArrayOptions::copyDataType(subArrShapeInfo, wholeShapeInfo); // type subArrShapeInfo[2 * subArrRank + 3] = shape::order(wholeShapeInfo); // order Nd4jLong* shape = new Nd4jLong[dimsSize]; @@ -4820,7 +4830,7 @@ INLINEDEF void calcSubArrShapeInfoAndOffset(const Nd4jLong* idx, const Nd4jLong* } minShapeInfo[2 * shape::rank(minShapeInfo) + 3] = shape::order(maxShapeInfo); // order - minShapeInfo[2 * shape::rank(minShapeInfo) + 1] = shape::type(maxShapeInfo); // type + sd::ArrayOptions::copyDataType(minShapeInfo, maxShapeInfo); // type shape::checkStridesEwsAndOrder(minShapeInfo); } @@ -5114,9 +5124,9 @@ INLINEDEF _CUDA_HD void excludeUnitiesFromShapeInfo(const Nd4jLong* inShapeInfo, shape::stride(outShapeInfo)[k++] = shape::stride(inShapeInfo)[i]; } - outShapeInfo[2 * outShapeInfo[0] + 1] = shape::type(inShapeInfo); // type - *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews - outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order + sd::ArrayOptions::copyDataType(outShapeInfo, inShapeInfo); // type + *shape::ews(outShapeInfo) = shape::elementWiseStride(inShapeInfo); // ews + outShapeInfo[2 * outShapeInfo[0] + 3] = shape::order(inShapeInfo); // order } diff --git a/libnd4j/include/legacy/NativeOpExecutioner.h b/libnd4j/include/legacy/NativeOpExecutioner.h index 56fb500db..4d55a3357 100644 --- a/libnd4j/include/legacy/NativeOpExecutioner.h +++ b/libnd4j/include/legacy/NativeOpExecutioner.h @@ -46,7 +46,7 @@ public: * @param resultShapeInfo */ static void execIndexReduceScalar(sd::LaunchContext *lc, - int opNum, + int opNum, void *hX, Nd4jLong *hXShapeInfo, void *dX, Nd4jLong *dXShapeInfo, void *extraParams, @@ -75,7 +75,7 @@ public: void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo); - + /** * @@ -263,6 +263,15 @@ static void execScalarInt(sd::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); + static void execBroadcast(sd::LaunchContext* lc, + const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo); + static void execInverseBroadcast(sd::LaunchContext *lc, int opNum, void *x, Nd4jLong *xShapeInfo, @@ -289,6 +298,15 @@ static void execScalarInt(sd::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); + static void execBroadcastBool(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams); + static void execInverseBroadcastBool(sd::LaunchContext *lc, int opNum, void *x, Nd4jLong *xShapeInfo, @@ -314,6 +332,14 @@ static void execScalarInt(sd::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); + static void execBroadcastInt(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + 
const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo); + static void execInverseBroadcastInt(sd::LaunchContext *lc, int opNum, void *x, Nd4jLong *xShapeInfo, @@ -325,7 +351,7 @@ static void execScalarInt(sd::LaunchContext *lc, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); - + /** * * @param opNum @@ -421,7 +447,7 @@ static void execTransformBool(sd::LaunchContext *lc, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams, - Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets); /** * * @param opNum @@ -509,7 +535,7 @@ static void execTransformBool(sd::LaunchContext *lc, void *dX, Nd4jLong *dXShapeInfo, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, - void *dZ, Nd4jLong *dZShapeInfo); + void *dZ, Nd4jLong *dZShapeInfo); static void execReduce3TAD(sd::LaunchContext *lc, int opNum, @@ -520,7 +546,7 @@ static void execTransformBool(sd::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, - int *dimension, int dimensionLength, + int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *yTadShapeInfo, Nd4jLong *yTadOffsets); @@ -544,7 +570,7 @@ static void execTransformBool(sd::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, - bool biasCorrected); + bool biasCorrected); /** * @@ -580,7 +606,7 @@ static void execTransformBool(sd::LaunchContext *lc, void *extraParams, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, - bool biasCorrected); + bool biasCorrected); static void execRandom(sd::LaunchContext *lc, @@ -627,7 +653,7 @@ static void execTransformBool(sd::LaunchContext *lc, int numRealArguments) { } - + inline static void execSort(void *x, Nd4jLong *xShapeInfo, bool descending) { auto xType = sd::ArrayOptions::dataType(xShapeInfo); diff --git a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp index 574978a7d..b3f15e345 100644 --- a/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/include/legacy/cpu/NativeOpExecutioner.cpp @@ -204,6 +204,30 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, #endif } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcast(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + #ifdef __ND4J_EXPERIMENTAL__ + BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES); + #else + BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), 
LIBND4J_TYPES); + #endif +} + void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, @@ -258,13 +282,12 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { - auto xType = sd::ArrayOptions::dataType(hXShapeInfo); - auto yType = sd::ArrayOptions::dataType(hYShapeInfo); - auto zType = sd::ArrayOptions::dataType(hZShapeInfo); - if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) return; + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + auto func = PRAGMA_THREADS_FOR { BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); }; @@ -276,6 +299,26 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, samediff::Threads::parallel_tad(func, 0, numTads); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams) { + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); +} + + void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, @@ -351,6 +394,33 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, samediff::Threads::parallel_tad(func, 0, numTads); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + + if (xType != yType || xType != zType) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); + + if (!sd::DataTypeUtils::isZ(zType)) + throw sd::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo), INTEGER_TYPES); + +} + void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, diff --git a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu 
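////////////////////////////////////////////////////////////////////////
// Editor's note: a rough sketch (an assumption, not the macro-generated
// code) of what BUILD_SINGLE_SELECTOR_THRICE expands to in the CPU
// execBroadcast overload above: a runtime switch over the DataType enum
// that picks one template instantiation, reusing xType for X, Y and Z.
// The __ND4J_EXPERIMENTAL__ pairwise variant enumerates type pairs
// instead.
//
//     switch (xType) {
//         case sd::DataType::FLOAT32:
//             functions::broadcast::Broadcast<float, float, float>::exec(
//                     opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo);
//             break;
//         case sd::DataType::DOUBLE:
//             functions::broadcast::Broadcast<double, double, double>::exec(
//                     opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo);
//             break;
//         // ... one case per entry in LIBND4J_TYPES ...
//         default:
//             throw std::runtime_error("unknown data type");
//     }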
b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu index 2a16efbda..00a9ea03f 100644 --- a/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu +++ b/libnd4j/include/legacy/cuda/NativeOpExecutioner.cu @@ -265,6 +265,39 @@ void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext *lc, throw cuda_exception::build("execBroadcastBool failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastBool(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo, + void *extraParams) { + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execBroadcastBool failed", res); +} + + void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, @@ -302,7 +335,6 @@ void NativeOpExecutioner::execInverseBroadcastBool(sd::LaunchContext *lc, throw cuda_exception::build("execInverseBroadcastBool failed", res); } - //////////////////////////////////////////////////////////////////////// void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, int opNum, @@ -341,6 +373,44 @@ void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext *lc, throw cuda_exception::build("execBroadcastBool failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastInt(sd::LaunchContext* lc, const int opNum, + const void *hX, const Nd4jLong *hXShapeInfo, + const void *dX, const Nd4jLong *dXShapeInfo, + const void *hY, const Nd4jLong *hYShapeInfo, + const void *dY, const Nd4jLong *dYShapeInfo, + void *hZ, const Nd4jLong *hZShapeInfo, + void *dZ, const Nd4jLong *dZShapeInfo) { + + auto stream = lc->getCudaStream(); + + auto xType = sd::ArrayOptions::dataType(hXShapeInfo); + auto yType = sd::ArrayOptions::dataType(hYShapeInfo); + auto zType = sd::ArrayOptions::dataType(hZShapeInfo); + + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); + + if (yType != xType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); + + dim3 launchDims; + + launchDims.y = MAX_NUM_THREADS / 4; // threadsPerBlock + launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = 1024; // shared memory + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, 
stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), INTEGER_TYPES)
+
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execBroadcastInt failed", res);
+}
+
 void NativeOpExecutioner::execInverseBroadcastInt(sd::LaunchContext *lc,
                             int opNum,
                             void *hX, Nd4jLong *hXShapeInfo,
@@ -428,6 +498,42 @@ void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc,
         throw cuda_exception::build("execBroadcast failed", res);
 }

+////////////////////////////////////////////////////////////////////////
+void NativeOpExecutioner::execBroadcast(sd::LaunchContext *lc, const int opNum,
+                            const void *hX, const Nd4jLong *hXShapeInfo,
+                            const void *dX, const Nd4jLong *dXShapeInfo,
+                            const void *hY, const Nd4jLong *hYShapeInfo,
+                            const void *dY, const Nd4jLong *dYShapeInfo,
+                            void *hZ, const Nd4jLong *hZShapeInfo,
+                            void *dZ, const Nd4jLong *dZShapeInfo) {
+
+    auto stream = lc->getCudaStream();
+
+    auto xType = sd::ArrayOptions::dataType(hXShapeInfo);
+    auto yType = sd::ArrayOptions::dataType(hYShapeInfo);
+    auto zType = sd::ArrayOptions::dataType(hZShapeInfo);
+
+    if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
+        return;
+
+    dim3 launchDims;
+
+    launchDims.y = MAX_NUM_THREADS / 4;   // threadsPerBlock
+    launchDims.x = (shape::length(hZShapeInfo) + launchDims.y - 1) / launchDims.y;  // blocksPerGrid
+    launchDims.z = 1024;   // shared memory
+
+#ifdef __ND4J_EXPERIMENTAL__
+    BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES, LIBND4J_TYPES);
+#else
+    BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo), LIBND4J_TYPES);
+#endif
+
+    // TODO: remove after the release
+    auto res = cudaStreamSynchronize(*stream);
+    if (res != 0)
+        throw cuda_exception::build("execBroadcast failed", res);
+}
+
 void NativeOpExecutioner::execInverseBroadcast(sd::LaunchContext *lc,
                             int opNum,
                             void *hX, Nd4jLong *hXShapeInfo,
diff --git a/libnd4j/include/loops/TrueBroadcastHelper.h b/libnd4j/include/loops/TrueBroadcastHelper.h
deleted file mode 100644
index 71934b674..000000000
--- a/libnd4j/include/loops/TrueBroadcastHelper.h
+++ /dev/null
@@ -1,84 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
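////////////////////////////////////////////////////////////////////////
// Editor's note: the dim3 in the new CUDA overloads above is used as a
// plain triple {blocksPerGrid, threadsPerBlock, sharedMem}, not as 3-D
// grid dimensions. Restated numerically (a sketch; MAX_NUM_THREADS is
// assumed to be 1024 here):
//
//     threadsPerBlock = 1024 / 4;                 // 256
//     blocksPerGrid   = (zLen + 256 - 1) / 256;   // ceil(zLen / 256)
//     sharedMem       = 1024;                     // 1 KB
//
// e.g. a {2, 3, 4} output (zLen = 24) launches one block of 256 threads
// with 1 KB of shared memory, and the kernel walks all 24 output
// offsets, deriving the x/y offsets from the unit-padded shapeInfos.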
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com) -// - -#ifndef LIBND4J_TRUEBROADCASTHELPER_H -#define LIBND4J_TRUEBROADCASTHELPER_H - -#include - -namespace sd { -namespace helpers { - -//////////////////////////////////////////////////////////////////////// -template -class TrueBroadcastHelper { - - #ifdef __CUDACC__ - template - static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo); - #else - template - static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr); - #endif - - public: - static void exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr); -}; - -template -class TrueBroadcastBoolHelper { - - #ifdef __CUDACC__ - template - static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo); - #else - template - static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr); - #endif - - public: - - static void exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr); -}; - -//////////////////////////////////////////////////////////////////////// -template -class TrueBroadcastIntHelper { - - #ifdef __CUDACC__ - template - static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo); - #else - template - static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr); - #endif - - public: - - static void exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr); -}; - - -} -} - - - -#endif //LIBND4J_BIDIAGONALUP_H diff --git a/libnd4j/include/loops/broadcasting.h b/libnd4j/include/loops/broadcasting.h index 4d53e4c73..20c95588c 100755 --- a/libnd4j/include/loops/broadcasting.h +++ b/libnd4j/include/loops/broadcasting.h @@ -69,11 +69,21 @@ namespace functions { int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + template static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void execBroadcast(dim3 launchDims, 
cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo); + template static __device__ void transformInverseCuda( @@ -170,6 +180,17 @@ namespace functions { Nd4jLong *tadOffsetZ, uint64_t start, uint64_t stop); + + static void exec(const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + #endif }; } diff --git a/libnd4j/include/loops/broadcasting_bool.h b/libnd4j/include/loops/broadcasting_bool.h index 23ed95a21..9bab82c81 100644 --- a/libnd4j/include/loops/broadcasting_bool.h +++ b/libnd4j/include/loops/broadcasting_bool.h @@ -69,11 +69,30 @@ namespace functions { int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams); + template static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams); + template static __device__ void transformInverseCuda( void *x, @@ -110,6 +129,12 @@ namespace functions { uint64_t start, uint64_t stop); + static void exec(const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams); + static void execInverse(int opNum, void *x, Nd4jLong *xShapeInfo, @@ -155,6 +180,12 @@ namespace functions { uint64_t start, uint64_t stop); + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams); + template static void execInverse(void *x, Nd4jLong *xShapeInfo, diff --git a/libnd4j/include/loops/broadcasting_int.h b/libnd4j/include/loops/broadcasting_int.h index 2c33491b2..81149ad8a 100644 --- a/libnd4j/include/loops/broadcasting_int.h +++ b/libnd4j/include/loops/broadcasting_int.h @@ -68,11 +68,27 @@ namespace functions { int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong 
*tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __device__ void transformCuda(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + template static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + template + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + template static __device__ void transformInverseCuda( void *x, @@ -107,6 +123,11 @@ namespace functions { uint64_t start, uint64_t stop); + static void exec(const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + static void execInverse(int opNum, void *x, Nd4jLong *xShapeInfo, @@ -150,6 +171,11 @@ namespace functions { uint64_t start, uint64_t stop); + template + static void exec(const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo); + template static void execInverse(void *x, Nd4jLong *xShapeInfo, diff --git a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp b/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp deleted file mode 100644 index 7edb9d90d..000000000 --- a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp +++ /dev/null @@ -1,291 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - // - // @author Yurii Shyrma (iuriish@yahoo.com) - // - -#include -#include -#include - -using namespace simdOps; - -namespace sd { - namespace helpers { - - //////////////////////////////////////////////////////////////////////// - template - template - void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - - - const X* x = reinterpret_cast(xArr.getBuffer()); - const Y* y = reinterpret_cast(yArr.getBuffer()); - Z* z = reinterpret_cast(zArr.getBuffer()); - - const auto xShapeInfo = xArr.getShapeInfo(); - const auto yShapeInfo = yArr.getShapeInfo(); - const auto zShapeInfo = zArr.getShapeInfo(); - - const int xRank = xArr.rankOf(); - const int yRank = yArr.rankOf(); - const int zRank = zArr.rankOf(); - - bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && - 1 == yArr.ews() && 'c' == yArr.ordering() && - 1 == zArr.ews() && 'c' == zArr.ordering()); - - if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) { - auto yLen = yArr.lengthOf(); - auto func = PRAGMA_THREADS_FOR{ - for (auto i = start; i < stop; i++) { - auto rZ = z + (i * yLen); - auto v = x[i]; - for (Nd4jLong j = 0; j < yLen; j++) { - rZ[j] = OpType::op(v, y[j]); - } - } - }; - samediff::Threads::parallel_tad(func, 0, xArr.lengthOf()); - return; - } - - - auto yShapeInt = yArr.getShapeAsVectorInt(); - auto xShapeInt = xArr.getShapeAsVectorInt(); - auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; }); - auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; }); - - bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX); - - if (bSpecialCase && bSpecialCase2) { - - uint32_t zDim1 = zArr.sizeAt(-2); - uint32_t zDim2 = zArr.sizeAt(-1); - - uint32_t nLen = zArr.lengthOf() / yArr.sizeAt(-1); - - auto func = PRAGMA_THREADS_FOR{ - for (auto total = start; total < stop; total++) { - - uint32_t i = total / zDim1; - uint32_t j = total % zDim1; - - uint32_t index = (i * zDim1) + j; - auto rZ = z + (index * zDim2); - auto rY = y + (i * zDim2); - auto rX = x[index]; - - for (uint32_t n = 0; n < zDim2; n++) { - rZ[n] = OpType::op(rX, rY[n]); - } - } - }; - samediff::Threads::parallel_tad(func, 0, nLen, 1); - return; - } - - - const Nd4jLong zLen = zArr.lengthOf(); - auto func = PRAGMA_THREADS_FOR{ - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - - for (auto i = start; i < stop; ++i) { - - shape::index2coords(i, zShapeInfo, zCoords.data()); - - for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - - if (ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } - else { - xCoords[ix--] = 0; - } - } - - if (iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } - else { - yCoords[iy--] = 0; - } - } - } - - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; - - samediff::Threads::parallel_for(func, 0, zLen); - } - - template - void TrueBroadcastHelper::exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - 
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS); - } - - //////////////////////////////////////////////////////////////////////// - template - template - void TrueBroadcastBoolHelper::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - - const X* x = reinterpret_cast(xArr.getBuffer()); - const X* y = reinterpret_cast(yArr.getBuffer()); - Z* z = reinterpret_cast(zArr.getBuffer()); - - const auto xShapeInfo = xArr.getShapeInfo(); - const auto yShapeInfo = yArr.getShapeInfo(); - const auto zShapeInfo = zArr.getShapeInfo(); - - const int xRank = xArr.rankOf(); - const int yRank = yArr.rankOf(); - const int zRank = zArr.rankOf(); - - const Nd4jLong zLen = zArr.lengthOf(); - - auto func = PRAGMA_THREADS_FOR{ - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - - for (auto i = start; i < stop; ++i) { - - shape::index2coords(i, zShapeInfo, zCoords.data()); - - for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - - if (ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } - else { - xCoords[ix--] = 0; - } - } - - if (iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } - else { - yCoords[iy--] = 0; - } - } - } - - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); - } - }; - - samediff::Threads::parallel_for(func, 0, zLen); - } - - template - void TrueBroadcastBoolHelper::exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS); - } - - //////////////////////////////////////////////////////////////////////// - template - template - void TrueBroadcastIntHelper::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - - const X* x = reinterpret_cast(xArr.getBuffer()); - const X* y = reinterpret_cast(yArr.getBuffer()); - X* z = reinterpret_cast(zArr.getBuffer()); - - const auto xShapeInfo = xArr.getShapeInfo(); - const auto yShapeInfo = yArr.getShapeInfo(); - const auto zShapeInfo = zArr.getShapeInfo(); - - const int xRank = xArr.rankOf(); - const int yRank = yArr.rankOf(); - const int zRank = zArr.rankOf(); - - const Nd4jLong zLen = zArr.lengthOf(); - - auto func = PRAGMA_THREADS_FOR{ - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - - for (auto i = start; i < stop; ++i) { - - shape::index2coords(i, zShapeInfo, zCoords.data()); - - for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - - if (ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } - else { - xCoords[ix--] = 0; - } - } - - if (iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } - else { - yCoords[iy--] = 0; - } - } - } - - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } - }; - - samediff::Threads::parallel_for(func, 0, zLen); - } - - template - void TrueBroadcastIntHelper::exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - 
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
-        }
-
-        /*
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
-        BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
-
-        BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
-
-        BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
-        */
-    }
-}
diff --git a/libnd4j/include/loops/cpu/broadcasting.hpp b/libnd4j/include/loops/cpu/broadcasting.hpp
index 11dcb56f2..d69396e70 100644
--- a/libnd4j/include/loops/cpu/broadcasting.hpp
+++ b/libnd4j/include/loops/cpu/broadcasting.hpp
@@ -30,7 +30,7 @@ using namespace simdOps;
 
 namespace functions {
-    namespace broadcast {
+namespace broadcast {
 
         template <typename X, typename Y, typename Z>
         void Broadcast<X, Y, Z>::execInverse(const int opNum,
@@ -93,6 +93,7 @@ namespace functions {
                                        zTadOffset, loopKind, start, stop), BROADCAST_OPS);
         }
 
+
         template <typename X, typename Y, typename Z>
         template <typename OpType>
         void Broadcast<X, Y, Z>::exec(void *vx,
@@ -562,5 +563,205 @@ namespace functions {
             };
         }
     }
+
+
+////////////////////////////////////////////////////////////////////////
+template <typename X, typename Y, typename Z>
+void Broadcast<X, Y, Z>::exec(const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
+
+    DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_OPS);
+}
+
+////////////////////////////////////////////////////////////////////////
+template <typename X, typename Y, typename Z>
+template <typename OpType>
+void Broadcast<X, Y, Z>::exec(const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
+
+    const X* x = reinterpret_cast<const X*>(vx);
+    const Y* y = reinterpret_cast<const Y*>(vy);
+    Z* z = reinterpret_cast<Z*>(vz);
+
+    const int  rank   = shape::rank(zShapeInfo);    // xRank = yRank = zRank
+    const char zOrder = shape::order(zShapeInfo);
+
+    uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
+    Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
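+
+    // note: all sizes and strides are read in z's traversal order ('c': axis 0 outermost, 'f':
+    // axes reversed). An axis being broadcast is assumed to come in with extent 1 and stride 0,
+    // so both the "strd ? i : 0" fast path and the generic "i * strd" indexing in the loops
+    // below simply reuse its single element. Ranks 1..5 get hand-unrolled loops; anything larger
+    // falls through to the coordinate-based default case.
+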
+    uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
+    Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
+
+    uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
+    Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1);
+    Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2);
+    Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0;
+    Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0;
+    Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0;
+
+    switch (rank) {
+
+        case 1: {
+
+            auto func = PRAGMA_THREADS_FOR {
+
+                if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1)
+                    for (auto i0 = start; i0 < stop; ++i0)
+                        z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0]);
+                else
+                    for (auto i0 = start; i0 < stop; ++i0)
+                        z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]);
+            };
+            samediff::Threads::parallel_tad(func, 0, zAxis0);
+        }
+        break;
+
+        case 2: {
+
+            auto func = PRAGMA_THREADS_FOR {
+
+                for (auto i0 = start; i0 < stop; ++i0) {
+
+                    auto x0 = x + i0 * xStrd0;
+                    auto y0 = y + i0 * yStrd0;
+                    auto z0 = z + i0 * zStrd0;
+
+                    if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1)
+                        for (uint i1 = 0; i1 < zAxis1; ++i1)
+                            z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? i1 : 0]);
+                    else
+                        for (uint i1 = 0; i1 < zAxis1; ++i1)
+                            z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]);
+                }
+            };
+            samediff::Threads::parallel_tad(func, 0, zAxis0);
+        }
+        break;
+
+        case 3: {
+
+
+            auto func = PRAGMA_THREADS_FOR_2D {
+
+                for (auto i0 = start_x; i0 < stop_x; ++i0) {
+                    for (auto i1 = start_y; i1 < stop_y; ++i1) {
+
+                        auto x1 = x + i0 * xStrd0 + i1 * xStrd1;
+                        auto y1 = y + i0 * yStrd0 + i1 * yStrd1;
+                        auto z1 = z + i0 * zStrd0 + i1 * zStrd1;
+
+                        if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1)
+                            for (uint i2 = 0; i2 < zAxis2; ++i2)
+                                z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ?
i2 : 0]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + } + break; + + case 4: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? i3 : 0]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + case 5: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + default: { + + auto func = PRAGMA_THREADS_FOR{ + + Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 
0 : zCoords[j]; + } + + const auto xOffset = shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + } } +} + +} } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.cpp b/libnd4j/include/loops/cpu/broadcasting_bool.cpp index 55f681355..418006794 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.cpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.cpp @@ -63,6 +63,16 @@ namespace functions { zTadOffset, start, stop), BROADCAST_BOOL_OPS); } + template + void BroadcastBool::exec(const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void* extraParams) { + + DISPATCH_BY_OPNUM_TT(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), BROADCAST_BOOL_OPS); + } + template void BroadcastBool::execInverse(const int opNum, void *x, @@ -267,8 +277,203 @@ namespace functions { } } +//////////////////////////////////////////////////////////////////////// +template +template +void BroadcastBool::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams) { - template + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); + + X* extraParams = reinterpret_cast(vextraParams); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + const char zOrder = shape::order(zShapeInfo); + + uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); + uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); + uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 
0 : rank-1); + uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + switch (rank) { + + case 1: { + + auto func = PRAGMA_THREADS_FOR{ + + if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1) + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0], extraParams); + else + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0], extraParams); + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); + } + break; + + case 2: { + + auto func = PRAGMA_THREADS_FOR{ + + for (auto i0 = start; i0 < stop; ++i0) { + + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? i1 : 0], extraParams); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1], extraParams); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); + } + break; + + case 3: { + + + auto func = PRAGMA_THREADS_FOR_2D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ? i2 : 0], extraParams); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2], extraParams); + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + } + break; + + case 4: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? 
i3 : 0], extraParams); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3], extraParams); + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + case 5: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0], extraParams); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4], extraParams); + } + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + default: { + + auto func = PRAGMA_THREADS_FOR{ + + Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j]; + } + + const auto xOffset = shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + } + } +} + + template template void BroadcastBool::execInverse(void *vx, Nd4jLong *xShapeInfo, diff --git a/libnd4j/include/loops/cpu/broadcasting_int.cpp b/libnd4j/include/loops/cpu/broadcasting_int.cpp index 231c56946..c3ccd024e 100644 --- a/libnd4j/include/loops/cpu/broadcasting_int.cpp +++ b/libnd4j/include/loops/cpu/broadcasting_int.cpp @@ -61,6 +61,15 @@ namespace functions { zTadOffset, start, stop), BROADCAST_INT_OPS); } + template + void BroadcastInt::exec(const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo) { + + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), BROADCAST_INT_OPS); + } + template void BroadcastInt::execInverse(const int opNum, void *x, @@ -430,6 +439,200 @@ namespace functions { } } - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + +//////////////////////////////////////////////////////////////////////// +template +template +void BroadcastInt::exec(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo) { + + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + X* z = reinterpret_cast(vz); + + const int rank = shape::rank(zShapeInfo); // xRank = yRank = zRank + const char zOrder = shape::order(zShapeInfo); + + uint xAxis0 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); + uint xAxis1 = shape::sizeAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint xAxis2 = rank > 2 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint xAxis3 = rank > 3 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 
3 : rank - 4) : 0; + uint xAxis4 = rank > 4 ? shape::sizeAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong xStrd0 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong xStrd1 = shape::strideAt(xShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong xStrd2 = rank > 2 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong xStrd3 = rank > 3 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong xStrd4 = rank > 4 ? shape::strideAt(xShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + uint yAxis0 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); + uint yAxis1 = shape::sizeAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint yAxis2 = rank > 2 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint yAxis3 = rank > 3 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + uint yAxis4 = rank > 4 ? shape::sizeAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong yStrd0 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong yStrd1 = shape::strideAt(yShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong yStrd2 = rank > 2 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong yStrd3 = rank > 3 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong yStrd4 = rank > 4 ? shape::strideAt(yShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + uint zAxis0 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); + uint zAxis1 = shape::sizeAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); + uint zAxis2 = rank > 2 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + uint zAxis3 = rank > 3 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + uint zAxis4 = rank > 4 ? shape::sizeAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + Nd4jLong zStrd0 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 0 : rank-1); + Nd4jLong zStrd1 = shape::strideAt(zShapeInfo, zOrder == 'c' ? 1 : rank-2); + Nd4jLong zStrd2 = rank > 2 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 2 : rank - 3) : 0; + Nd4jLong zStrd3 = rank > 3 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 3 : rank - 4) : 0; + Nd4jLong zStrd4 = rank > 4 ? shape::strideAt(zShapeInfo, zOrder == 'c' ? 4 : rank - 5) : 0; + + switch (rank) { + + case 1: { + + auto func = PRAGMA_THREADS_FOR{ + + if(zStrd0 == 1 && xStrd0 <= 1 && yStrd0 <= 1) + for (auto i0 = start; i0 < stop; ++i0) + z[i0] = OpType::op(x[xStrd0 ? i0 : 0], y[yStrd0 ? i0 : 0]); + else + for (auto i0 = start; i0 < stop; ++i0) + z[i0 * zStrd0] = OpType::op(x[i0 * xStrd0], y[i0 * yStrd0]); + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); + } + break; + + case 2: { + + auto func = PRAGMA_THREADS_FOR{ + + for (auto i0 = start; i0 < stop; ++i0) { + + auto x0 = x + i0 * xStrd0; + auto y0 = y + i0 * yStrd0; + auto z0 = z + i0 * zStrd0; + + if(zStrd1 == 1 && xStrd1 <= 1 && yStrd1 <= 1) + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1] = OpType::op(x0[xStrd1 ? i1 : 0], y0[yStrd1 ? 
i1 : 0]); + else + for (uint i1 = 0; i1 < zAxis1; ++i1) + z0[i1 * zStrd1] = OpType::op(x0[i1 * xStrd1], y0[i1 * yStrd1]); + } + }; + samediff::Threads::parallel_tad(func, 0, zAxis0); + } + break; + + case 3: { + + + auto func = PRAGMA_THREADS_FOR_2D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + + auto x1 = x + i0 * xStrd0 + i1 * xStrd1; + auto y1 = y + i0 * yStrd0 + i1 * yStrd1; + auto z1 = z + i0 * zStrd0 + i1 * zStrd1; + + if(zStrd2 == 1 && xStrd2 <= 1 && yStrd2 <= 1) + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2] = OpType::op(x1[xStrd2 ? i2 : 0], y1[yStrd2 ? i2 : 0]); + else + for (uint i2 = 0; i2 < zAxis2; ++i2) + z1[i2 * zStrd2] = OpType::op(x1[i2 * xStrd2], y1[i2 * yStrd2]); + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1); + } + break; + + case 4: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + + auto x2 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2; + auto y2 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2; + auto z2 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2; + + if(zStrd3 == 1 && xStrd3 <= 1 && yStrd3 <= 1) + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3] = OpType::op(x2[xStrd3 ? i3 : 0], y2[yStrd3 ? i3 : 0]); + else + for (uint i3 = 0; i3 < zAxis3; ++i3) + z2[i3 * zStrd3] = OpType::op(x2[i3 * xStrd3], y2[i3 * yStrd3]); + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + case 5: { + + auto func = PRAGMA_THREADS_FOR_3D { + + for (auto i0 = start_x; i0 < stop_x; ++i0) { + for (auto i1 = start_y; i1 < stop_y; ++i1) { + for (auto i2 = start_z; i2 < stop_z; ++i2) { + for (uint i3 = 0; i3 < zAxis3; ++i3) { + + auto x3 = x + i0 * xStrd0 + i1 * xStrd1 + i2 * xStrd2 + i3 * xStrd3; + auto y3 = y + i0 * yStrd0 + i1 * yStrd1 + i2 * yStrd2 + i3 * yStrd3; + auto z3 = z + i0 * zStrd0 + i1 * zStrd1 + i2 * zStrd2 + i3 * zStrd3; + + if(zStrd4 == 1 && xStrd4 <= 1 && yStrd4 <= 1) + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4] = OpType::op(x3[xStrd4 ? i4 : 0], y3[yStrd4 ? i4 : 0]); + else + for (uint i4 = 0; i4 < zAxis4; ++i4) + z3[i4 * zStrd4] = OpType::op(x3[i4 * xStrd4], y3[i4 * yStrd4]); + } + } + } + } + }; + samediff::Threads::parallel_for(func, 0,zAxis0,1, 0,zAxis1,1, 0,zAxis2,1); + } + break; + + default: { + + auto func = PRAGMA_THREADS_FOR{ + + Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (auto i = start; i < stop; ++i) { + + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 
0 : zCoords[j]; + } + + const auto xOffset = shape::getOffset(xShapeInfo, xCoords); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + } + }; + + samediff::Threads::parallel_for(func, 0, shape::length(zShapeInfo)); + } } +} + +BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); +} } \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_0.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_0.cpp deleted file mode 100644 index 7bb2e8d81..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_0.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_1.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_1.cpp deleted file mode 100644 index 7cacbe035..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_1.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_2.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_2.cpp deleted file mode 100644 index c98abca4a..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_2.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_3.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_3.cpp deleted file mode 100644 index 10d29053c..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_3.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_4.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_4.cpp deleted file mode 100644 index ef72d78f8..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_4.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_5.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_5.cpp deleted file mode 100644 index c90b23f8e..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_5.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_6.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_6.cpp deleted file mode 100644 index ba6bcb5a3..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_6.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_7.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_7.cpp deleted file mode 100644 index e8c62f3b7..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_7.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_8.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_8.cpp deleted file mode 100644 index 7771d89b8..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_8.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_9.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_9.cpp deleted file mode 100644 index 1c1cf1cb0..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_9.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_bool.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_bool.cpp deleted file mode 100644 index 23013f8a0..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_bool.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES); - } -} diff --git a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_int.cpp b/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_int.cpp deleted file mode 100644 index 00d4b6a93..000000000 --- a/libnd4j/include/loops/cpu/compilation_units/TrueBroadcastHelper_int.cpp +++ /dev/null @@ -1,27 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2019-2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include "../TrueBroadcastHelper.hpp" - -namespace sd { - namespace helpers { - BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES); - } -} diff --git a/libnd4j/include/loops/cuda/TrueBroadcastHelper.cu b/libnd4j/include/loops/cuda/TrueBroadcastHelper.cu deleted file mode 100644 index 9a775b2e7..000000000 --- a/libnd4j/include/loops/cuda/TrueBroadcastHelper.cu +++ /dev/null @@ -1,312 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com) -// - -// #include -#include -#include -#include -#include -#include -#include -// #include -// #include - -using namespace simdOps; - -namespace sd { -namespace helpers { - -//////////////////////////////////////////////////////////////////////// -template -__global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); - auto yCoords = xCoords + xRank; - auto zCoords = yCoords + yRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, zCoords); - - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - - if(ix >= 0) - if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) - xCoords[ix--] = zCoords[iz]; - else - xCoords[ix--] = 0; - - if(iy >= 0) - if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) - yCoords[iy--] = zCoords[iz]; - else - yCoords[iy--] = 0; - } - - const auto xOffset = shape::getOffset(xShapeInfo, xCoords); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); - } -} - -//////////////////////////////////////////////////////////////////////// -template -template -void TrueBroadcastHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - trueBroadcastCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); -} - -////////////////////////////////////////////////////////////////////////// -template -void TrueBroadcastHelper::exec(const sd::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe - - PointersManager manager(xArr.getContext(), 
"TrueBroadcastHelper::exec"); - - NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr}); - - DISPATCH_BY_OPNUM_TTT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_OPS)); - - NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr}); - - manager.synchronize(); -} - -//////////////////////////////////////////////////////////////////////// -template -__global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) { - - const auto x = reinterpret_cast(vx); - const auto y = reinterpret_cast(vy); - auto z = reinterpret_cast(vz); - - __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen - - if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - - xRank = shape::rank(xShapeInfo); - yRank = shape::rank(yShapeInfo); - zRank = shape::rank(zShapeInfo); - - zLen = shape::length(zShapeInfo); - totalThreads = gridDim.x * blockDim.x; - } - __syncthreads(); - - auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); - auto yCoords = xCoords + xRank; - auto zCoords = yCoords + yRank; - - const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - - for (Nd4jLong i = tid; i < zLen; i += totalThreads) { - - shape::index2coords(i, zShapeInfo, zCoords); - - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - - if(ix >= 0) - if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) - xCoords[ix--] = zCoords[iz]; - else - xCoords[ix--] = 0; - - if(iy >= 0) - if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) - yCoords[iy--] = zCoords[iz]; - else - yCoords[iy--] = 0; - } - - const auto xOffset = shape::getOffset(xShapeInfo, xCoords); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); - } -} - -//////////////////////////////////////////////////////////////////////// -template -template -void TrueBroadcastBoolHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { - - trueBroadcastBoolCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); -} - -////////////////////////////////////////////////////////////////////////// -template -void TrueBroadcastBoolHelper::exec(const sd::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { - - dim3 launchDims; - - launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid - launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe - - PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); - - NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr}); - - DISPATCH_BY_OPNUM_TT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_BOOL_OPS)); - - NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr}); - 
-
-////////////////////////////////////////////////////////////////////////
-template <typename X, typename OpType>
-__global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
-
-    const auto x = reinterpret_cast<const X*>(vx);
-    const auto y = reinterpret_cast<const X*>(vy);
-    auto z = reinterpret_cast<X*>(vz);
-
-    __shared__ int xRank, yRank, zRank;
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
-
-    if (threadIdx.x == 0) {
-        extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
-
-        xRank = shape::rank(xShapeInfo);
-        yRank = shape::rank(yShapeInfo);
-        zRank = shape::rank(zShapeInfo);
-
-        zLen = shape::length(zShapeInfo);
-        totalThreads = gridDim.x * blockDim.x;
-    }
-    __syncthreads();
-
-    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
-    auto yCoords = xCoords + xRank;
-    auto zCoords = yCoords + yRank;
-
-    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
-
-    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
-
-        shape::index2coords(i, zShapeInfo, zCoords);
-
-        for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
-
-            if(ix >= 0)
-                if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
-                    xCoords[ix--] = zCoords[iz];
-                else
-                    xCoords[ix--] = 0;
-
-            if(iy >= 0)
-                if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
-                    yCoords[iy--] = zCoords[iz];
-                else
-                    yCoords[iy--] = 0;
-        }
-
-        const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
-        const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
-        const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
-
-        z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
-    }
-}
-
-////////////////////////////////////////////////////////////////////////
-template <typename X>
-template <typename OpType>
-void TrueBroadcastIntHelper<X>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
-
-    trueBroadcastIntCuda<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
-}
-
-//////////////////////////////////////////////////////////////////////////
-template <typename X>
-void TrueBroadcastIntHelper<X>::exec(const sd::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
-
-    dim3 launchDims;
-
-    launchDims.y = MAX_NUM_THREADS / 8;  // threadsPerBlock
-    launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y;  // blocksPerGrid
-    launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128;  // sharedMem
-
-    PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec");
-
-    NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
-
-    DISPATCH_BY_OPNUM_T(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_INT_OPS));
-
-    NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
-
-    manager.synchronize();
-}
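The instantiation lists that follow are split into PAIRWISE_TYPES_0 through PAIRWISE_TYPES_9 chunks so that no single translation unit has to instantiate the full (X, Y, Z) type cross-product at once. Each BUILD_PAIRWISE_TEMPLATE line expands to a batch of explicit instantiations, schematically like this (the concrete triples here are illustrative; the real list comes from the type-chunk macro):

    template class ND4J_EXPORT TrueBroadcastHelper<float, float, float>;
    template class ND4J_EXPORT TrueBroadcastHelper<float, double, double>;
    // ... one explicit instantiation per (X, Y, Z) triple in PAIRWISE_TYPES_0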
-
-
-
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
-BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
-
-BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
-
-BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
-
-}
-}
\ No newline at end of file
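Deleting this whole helper family works because applyTrueBroadcast now left-pads the smaller operand's shape info with unit dimensions (via ConstantShapeHelper::createShapeInfoWithUnitiesForBroadcast) and routes through the regular NativeOpExecutioner::execBroadcast path. A minimal standalone sketch of that padding step, using plain vectors instead of libnd4j shape-info buffers (helper name hypothetical):

    #include <cstdint>
    #include <vector>

    // Left-pad `shape` with 1s up to targetRank, e.g. {3, 4} -> {1, 1, 3, 4}.
    std::vector<int64_t> extendWithUnities(const std::vector<int64_t>& shape, size_t targetRank) {
        std::vector<int64_t> out(targetRank, 1);
        for (size_t i = 0; i < shape.size(); ++i)
            out[targetRank - shape.size() + i] = shape[i];  // keep trailing dims aligned
        return out;
    }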
diff --git a/libnd4j/include/loops/cuda/broadcasting.chpp b/libnd4j/include/loops/cuda/broadcasting.chpp
index f8d232349..f54386975 100644
--- a/libnd4j/include/loops/cuda/broadcasting.chpp
+++ b/libnd4j/include/loops/cuda/broadcasting.chpp
@@ -46,6 +46,15 @@
     functions::broadcast::Broadcast<X, Y, Z>::template transformCuda<OpType>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
 }
 
+template <typename X, typename Y, typename Z, typename OpType>
+static __global__ void broadcastSimple(const void *x, const Nd4jLong *xShapeInfo,
+                                       const void *y, const Nd4jLong *yShapeInfo,
+                                       void *z, const Nd4jLong *zShapeInfo ) {
+
+    functions::broadcast::Broadcast<X, Y, Z>::template transformCuda<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
+}
+
+
 template <typename X, typename Y, typename Z, typename OpType>
 static __global__ void broadcastInverseSimple(
         void *x,
@@ -64,11 +73,11 @@
 namespace functions {
 namespace broadcast {
 
-    static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+    static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
        return shape::getIndexOffset(index, shapeInfo);
    }
 
-    static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
+    static Nd4jLong __device__ __noinline__ length(Nd4jLong *shapeInfo) {
        return shape::length(shapeInfo);
    }
 
@@ -78,6 +87,12 @@
        broadcastSimple<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
    }
 
+    template <typename X, typename Y, typename Z>
+    template <typename OpType>
+    __host__ void Broadcast<X, Y, Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
+        broadcastSimple<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
+    }
+
    template <typename X, typename Y, typename Z>
    __host__ void Broadcast<X, Y, Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_OPS))
@@ -85,6 +100,13 @@
        DEBUG_KERNEL(stream, opNum);
    }
 
+    template <typename X, typename Y, typename Z>
+    __host__ void Broadcast<X, Y, Z>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, const void *x, const Nd4jLong *xShapeInfo, const void *y, const Nd4jLong *yShapeInfo, void *z, const Nd4jLong *zShapeInfo) {
+
+        DISPATCH_BY_OPNUM_TTT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_OPS))
+
+        DEBUG_KERNEL(stream, opNum);
+    }
+
    template <typename X, typename Y, typename Z>
    template <typename OpType>
    __host__ void Broadcast<X, Y, Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@@ -128,9 +150,9 @@
 
        if (threadIdx.x == 0) {
-            tadLength = _length(tadOnlyShapeInfo);
+            tadLength = length(tadOnlyShapeInfo);
            tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
-            numTads = _length(yShapeInfo) / tadLength;
+            numTads = length(yShapeInfo) / tadLength;
            xEWS = shape::elementWiseStride(xShapeInfo);
            zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
        }
@@ -154,9 +176,9 @@
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-                auto xOffset = _getIndexOffset(i, xShapeInfo);
-                auto yOffset = _getIndexOffset(i, tadOnlyShapeInfo);
-                auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
+                auto xOffset = getIndexOffset(i, xShapeInfo);
+                auto yOffset = getIndexOffset(i, tadOnlyShapeInfo);
+                auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ);
 
                rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
            }
        }
@@ -193,9 +215,9 @@
        __shared__ Nd4jLong zEWS;
 
        if (threadIdx.x == 0) {
-            tadLength = _length(tadOnlyShapeInfo);
+            tadLength = length(tadOnlyShapeInfo);
            tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
-            numTads = _length(xShapeInfo) / tadLength;
+            numTads = length(xShapeInfo) / tadLength;
            yEWS = shape::elementWiseStride(yShapeInfo);
            zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
        }
@@ -219,15 +241,59 @@
                // it is expected that x and z tads and y array all have the same length
                for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
-                    auto xOffset = _getIndexOffset(i, tadOnlyShapeInfo);
-                    auto yOffset = _getIndexOffset(i, yShapeInfo);
-                    auto zOffset = _getIndexOffset(i, tadOnlyShapeInfoZ);
+                    auto xOffset = getIndexOffset(i, tadOnlyShapeInfo);
+                    auto yOffset = getIndexOffset(i, yShapeInfo);
+                    auto zOffset = getIndexOffset(i, tadOnlyShapeInfoZ);
 
                    rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
                }
            }
        }
    }
 
+////////////////////////////////////////////////////////////////////////
+template <typename X, typename Y, typename Z>
+template <typename OpType>
+__device__ void Broadcast<X, Y, Z>::transformCuda(
+            const void *vx, const Nd4jLong *xShapeInfo,
+            const void *vy, const Nd4jLong *yShapeInfo,
+            void *vz, const Nd4jLong *zShapeInfo) {
+
+    const X* x = reinterpret_cast<const X*>(vx);
+    const Y* y = reinterpret_cast<const Y*>(vy);
+    Z* z = reinterpret_cast<Z*>(vz);
+
+    __shared__ Nd4jLong zLen;
+    __shared__ int rank;
+
+    if (threadIdx.x == 0) {
+
+        zLen = shape::length(zShapeInfo);
+        rank = shape::rank(zShapeInfo);
+    }
+    __syncthreads();
+
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
+
+    for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) {
+
+        shape::index2coords(i, zShapeInfo, zCoords);
+
+        for (uint j = 0; j < rank; ++j) {
+            xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
+            yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
+        }
+
+        const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
+        const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
+        const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
+
+        z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
+    }
+}
+
 /*
 BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_0);
 BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT Broadcast, , PAIRWISE_TYPES_1);
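The new transformCuda above compresses the whole broadcast rule into one line per operand: a size-1 dimension contributes coordinate 0, any other dimension reuses the output coordinate. The same mapping on the host, for clarity (hypothetical helper, plain C++ instead of shape-info buffers):

    #include <cstdint>
    #include <vector>

    // e.g. zCoords {2, 3} against inShape {1, 4} reads the input at {0, 3}
    std::vector<int64_t> broadcastCoords(const std::vector<int64_t>& zCoords,
                                         const std::vector<int64_t>& inShape) {
        std::vector<int64_t> c(zCoords.size());
        for (size_t j = 0; j < zCoords.size(); ++j)
            c[j] = (inShape[j] == 1) ? 0 : zCoords[j];  // pin broadcast dims to 0
        return c;
    }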
diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu
index 47df0fd2b..513db1b7c 100644
--- a/libnd4j/include/loops/cuda/broadcasting_bool.cu
+++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu
@@ -47,6 +47,15 @@
     functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpType>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
 }
 
+//////////////////////////////////////////////////////////////////////////
+template <typename X, typename Z, typename OpType>
+static __global__ void broadcastBoolSimple(const void *x, const Nd4jLong *xShapeInfo,
+                                           const void *y, const Nd4jLong *yShapeInfo,
+                                           void *z, const Nd4jLong *zShapeInfo,
+                                           void *extraParams) {
+
+    functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams);
+}
 //////////////////////////////////////////////////////////////////////////
 template <typename X, typename Z, typename OpType>
 static __global__ void broadcastBoolInverseSimple(
@@ -64,22 +73,48 @@
 }
 
 namespace functions {
-    namespace broadcast {
-//////////////////////////////////////////////////////////////////////////
-        template <typename X, typename Z>
-        template <typename OpType>
-        __host__ void BroadcastBool<X, Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
-            broadcastBoolSimple<X, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
-            sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
-        }
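The free-function broadcastBoolSimple wrapper exists because CUDA does not allow __global__ kernels to be member functions: the host launches the wrapper, which immediately forwards into the class's static __device__ template. The pattern in isolation (self-contained, illustrative names):

    // A __global__ kernel cannot live inside a class, so a free template
    // wrapper forwards into a static __device__ member of a functor.
    struct GreaterOp {
        static __device__ bool op(float a, float b) { return a > b; }
    };

    template <typename OpType>
    __global__ void boolKernelWrapper(const float* x, const float* y, bool* z, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n)
            z[i] = OpType::op(x[i], y[i]);
    }
    // launch: boolKernelWrapper<GreaterOp><<<blocks, threads>>>(x, y, z, n);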
failed"); - } +namespace broadcast { ////////////////////////////////////////////////////////////////////////// - template - __host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) +template +template +__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); +} - DEBUG_KERNEL(stream, opNum); - } +////////////////////////////////////////////////////////////////////////// +template +template +__host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams) { + + broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams); + sd::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) 
failed"); +} + +////////////////////////////////////////////////////////////////////////// +template +__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); +} + +////////////////////////////////////////////////////////////////////////// +template +__host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum, + const void *x, const Nd4jLong *xShapeInfo, + const void *y, const Nd4jLong *yShapeInfo, + void *z, const Nd4jLong *zShapeInfo, + void *extraParams) { + + DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams), OPS_A(BROADCAST_BOOL_OPS)) + DEBUG_KERNEL(stream, opNum); +} ////////////////////////////////////////////////////////////////////////// template @@ -229,7 +264,53 @@ namespace functions { } } +////////////////////////////////////////////////////////////////////////// +template +template +__device__ void BroadcastBool::transformCuda(const void *vx, const Nd4jLong *xShapeInfo, + const void *vy, const Nd4jLong *yShapeInfo, + void *vz, const Nd4jLong *zShapeInfo, + void *vextraParams) { - BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT BroadcastBool, , LIBND4J_TYPES, BOOL_TYPES); + const X* x = reinterpret_cast(vx); + const X* y = reinterpret_cast(vy); + Z* z = reinterpret_cast(vz); + + auto extraParams = reinterpret_cast(vextraParams); + + __shared__ Nd4jLong zLen; + __shared__ int rank; + + if (threadIdx.x == 0) { + + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); } + __syncthreads(); + + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK]; + + for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) { + + shape::index2coords(i, zShapeInfo, zCoords); + + for (uint j = 0; j < rank; ++j) { + xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j]; + yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 
diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu
index 9a725a886..651aaecc5 100644
--- a/libnd4j/include/loops/cuda/broadcasting_int.cu
+++ b/libnd4j/include/loops/cuda/broadcasting_int.cu
@@ -46,6 +46,15 @@
     functions::broadcast::BroadcastInt<X>::template transformCuda<OpType>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
 }
 
+//////////////////////////////////////////////////////////////////////////
+template <typename X, typename OpType>
+static __global__ void broadcastIntSimple(const void *x, const Nd4jLong *xShapeInfo,
+                                          const void *y, const Nd4jLong *yShapeInfo,
+                                          void *z, const Nd4jLong *zShapeInfo) {
+
+    functions::broadcast::BroadcastInt<X>::template transformCuda<OpType>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
+}
+
 //////////////////////////////////////////////////////////////////////////
 template <typename X, typename OpType>
 static __global__ void broadcastBoolInverseSimple(
@@ -62,19 +71,40 @@
 }
 
 namespace functions {
-    namespace broadcast {
+namespace broadcast {
 
 //////////////////////////////////////////////////////////////////////////
-        template <typename X>
-        template <typename OpType>
-        __host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
-            broadcastIntSimple<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
-        }
+template <typename X>
+template <typename OpType>
+__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
+    broadcastIntSimple<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
+}
 
 //////////////////////////////////////////////////////////////////////////
-        template <typename X>
-        __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
-            DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
-        }
+template <typename X>
+template <typename OpType>
+__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream,
+                                                     const void *x, const Nd4jLong *xShapeInfo,
+                                                     const void *y, const Nd4jLong *yShapeInfo,
+                                                     void *z, const Nd4jLong *zShapeInfo) {
+
+    broadcastIntSimple<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo);
+}
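As in the other two files, execBroadcast does nothing but convert the runtime opNum into a compile-time OpType via DISPATCH_BY_OPNUM_T, once, at the host boundary. A self-contained stand-in for the pattern (all names hypothetical):

    #include <stdexcept>

    struct AddOp { static int op(int a, int b) { return a + b; } };
    struct MulOp { static int op(int a, int b) { return a * b; } };

    template <typename OpType>
    int runOp(int a, int b) { return OpType::op(a, b); }

    // One switch at the boundary; everything past it is statically dispatched.
    int dispatchByOpNum(int opNum, int a, int b) {
        switch (opNum) {
            case 0:  return runOp<AddOp>(a, b);
            case 1:  return runOp<MulOp>(a, b);
            default: throw std::runtime_error("unknown opNum");
        }
    }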
+
+//////////////////////////////////////////////////////////////////////////
+template <typename X>
+__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
+    DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename X>
+__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, const int opNum,
+                                             const void *x, const Nd4jLong *xShapeInfo,
+                                             const void *y, const Nd4jLong *yShapeInfo,
+                                             void *z, const Nd4jLong *zShapeInfo) {
+
+    DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo), OPS_A(BROADCAST_INT_OPS))
+}
 
 //////////////////////////////////////////////////////////////////////////
 template <typename X>
@@ -217,6 +247,50 @@
     }
 }
 
-    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
+//////////////////////////////////////////////////////////////////////////
+template <typename X>
+template <typename OpType>
+__device__ void BroadcastInt<X>::transformCuda(const void *vx, const Nd4jLong *xShapeInfo,
+                                               const void *vy, const Nd4jLong *yShapeInfo,
+                                               void *vz, const Nd4jLong *zShapeInfo) {
+
+    const X* x = reinterpret_cast<const X*>(vx);
+    const X* y = reinterpret_cast<const X*>(vy);
+    X* z = reinterpret_cast<X*>(vz);
+
+    __shared__ Nd4jLong zLen;
+    __shared__ int rank;
+
+    if (threadIdx.x == 0) {
+
+        zLen = shape::length(zShapeInfo);
+        rank = shape::rank(zShapeInfo);
     }
+    __syncthreads();
+
+
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    Nd4jLong xCoords[MAX_RANK], yCoords[MAX_RANK], zCoords[MAX_RANK];
+
+    for (int i = tid; i < zLen; i += blockDim.x * gridDim.x) {
+
+        shape::index2coords(i, zShapeInfo, zCoords);
+
+        for (uint j = 0; j < rank; ++j) {
+            xCoords[j] = shape::sizeAt(xShapeInfo, j) == 1 ? 0 : zCoords[j];
+            yCoords[j] = shape::sizeAt(yShapeInfo, j) == 1 ? 0 : zCoords[j];
+        }
+
+        const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
+        const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
+        const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
+
+        z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
+    }
+}
+
+
+BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
+}
 }
\ No newline at end of file
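All three new transformCuda bodies walk the flattened output with a grid-stride loop, so correctness never depends on the grid being large enough to give each element its own thread. The idiom in its minimal standalone form:

    __global__ void gridStrideFill(float* z, long long zLen, float value) {
        // each thread starts at its global id and strides by the whole grid
        for (long long i = blockIdx.x * blockDim.x + threadIdx.x; i < zLen;
             i += (long long)blockDim.x * gridDim.x)
            z[i] = value;
    }
    // a fixed <<<256, 256>>> launch covers any zLen, at the cost of looping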
diff --git a/libnd4j/include/loops/cuda/specials/tileKernel.cu b/libnd4j/include/loops/cuda/specials/tileKernel.cu
index 885978fef..d6076d6cb 100644
--- a/libnd4j/include/loops/cuda/specials/tileKernel.cu
+++ b/libnd4j/include/loops/cuda/specials/tileKernel.cu
@@ -21,17 +21,14 @@
 #include
 
 namespace sd {
-    static Nd4jLong __device__ __noinline__ _getIndexOffset(Nd4jLong index, Nd4jLong *shapeInfo) {
+    static Nd4jLong __device__ __noinline__ getIndexOffset_(Nd4jLong index, Nd4jLong *shapeInfo) {
        return shape::getIndexOffset(index, shapeInfo);
    }
 
-    static Nd4jLong __device__ __noinline__ _subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
+    static Nd4jLong __device__ __noinline__ subArrayOffset(Nd4jLong index, Nd4jLong *shapeInfoA, Nd4jLong *shapeInfoB) {
        return shape::subArrayOffset(index, shapeInfoA, shapeInfoB);
    }
 
-    static Nd4jLong __device__ __noinline__ _length(Nd4jLong *shapeInfo) {
-        return shape::length(shapeInfo);
-    }
 
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//  tileKernel:
@@ -48,13 +45,13 @@
        int totalThreads = gridDim.x * blockDim.x;
 
        if (shape::order(outputShape) == 'c') {   //  ews == 1 always here
            for (int i = tid; i < resultLength; i += totalThreads) {
-                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+                auto yOffset = subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<T*>(outputBuffer) + i) = *(reinterpret_cast<T const*>(inputBuffer) + yOffset);
            }
        }
        else {
            for (int i = tid; i < resultLength; i += totalThreads) {
-                auto xOffset = _getIndexOffset(i, outputShape);
-                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+                auto xOffset = getIndexOffset_(i, outputShape);
+                auto yOffset = subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<T*>(outputBuffer) + xOffset) = *(reinterpret_cast<T const*>(inputBuffer) + yOffset);
            }
        }
@@ -83,20 +80,20 @@
        if (ordering == 'c' && ews == 1) {   //  ews == 1 always here
            for (int i = tid; i < resultLength; i += totalThreads) {
-                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+                auto yOffset = subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<X*>(outputBuffer) + i) = static_cast<X>(*(reinterpret_cast<Y const*>(inputBuffer) + yOffset));
            }
        }
        else if (ordering == 'c' && ews > 1) {
            for (int i = tid; i < resultLength; i += totalThreads) {
-                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+                auto yOffset = subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<X*>(outputBuffer) + i * ews) = static_cast<X>(*(reinterpret_cast<Y const*>(inputBuffer) + yOffset));
            }
        }
        else {
            for (int i = tid; i < resultLength; i += totalThreads) {
-                auto xOffset = _getIndexOffset(i, outputShape);
-                auto yOffset = _subArrayOffset(i, outputShape, inputShape);
+                auto xOffset = getIndexOffset_(i, outputShape);
+                auto yOffset = subArrayOffset(i, outputShape, inputShape);
                *(reinterpret_cast<X*>(outputBuffer) + xOffset) = static_cast<X>(*(reinterpret_cast<Y const*>(inputBuffer) + yOffset));
            }
        }
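tileKernel solves the inverse of the broadcast mapping: for each output index it finds which input element is being repeated, which is what shape::subArrayOffset computes. A host-side sketch for the row-major case (hypothetical helper; the real function works on shape-info buffers and handles both orderings):

    #include <cstdint>
    #include <vector>

    // Flat output index -> flat input offset, wrapping coordinates into inShape.
    int64_t tiledSourceOffset(int64_t i, const std::vector<int64_t>& outShape,
                              const std::vector<int64_t>& inShape) {
        int64_t off = 0, stride = 1;
        for (int j = (int)outShape.size() - 1; j >= 0; --j) {
            int64_t coord = i % outShape[j];       // output coordinate along axis j
            i /= outShape[j];
            off += (coord % inShape[j]) * stride;  // wrap into the input extent
            stride *= inShape[j];
        }
        return off;
    }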
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
index ac969ebbd..0769a9aef 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -170,7 +170,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
     graph->getVariableSpace()->putVariable(86,0, u);
     graph->getVariableSpace()->putVariable(87,0, v);
 
-
+/*
     // validating graph now
     auto status = GraphExecutioner::execute(graph);
     ASSERT_EQ(Status::OK(), status);
@@ -179,7 +179,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
     auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
 
     ASSERT_EQ(z, *array);
-/*
+*/
     sd::Environment::getInstance()->setProfiling(true);
     auto profile = GraphProfilingHelper::profile(graph, 1);
 
@@ -187,7 +187,7 @@ TEST_F(PlaygroundTests, test_bert_1) {
     sd::Environment::getInstance()->setProfiling(false);
 
     delete profile;
-*/
+
     /*
     std::vector values;