[WIP] Broadcast changes (#8257)
* - provide correct call NDArray::applyBroadcast inside of NDArray::applyTrueBroadcast Signed-off-by: Yurii <yurii@skymind.io> * - provide new trueBroadcast helper Signed-off-by: Yurii <yurii@skymind.io> * example for yurii Signed-off-by: raver119 <raver119@gmail.com> * - provide new trueBroadcast helper for cpu Signed-off-by: Yurii <yurii@skymind.io> * - start working on new trueBroadcat helper for cuda Signed-off-by: Yurii <yurii@skymind.io> * - further work on trueBroadcast for cuda Signed-off-by: Yurii <yurii@skymind.io> * - fix bugs in cuda helper trueBroadcast Signed-off-by: Yurii <yurii@skymind.io>master
parent
5959ff4795
commit
44a8d19ac6
|
@ -25,6 +25,7 @@
|
|||
#include <ConstantTadHelper.h>
|
||||
#include <BroadcastPairwiseConverter.h>
|
||||
#include <helpers/PointersManager.h>
|
||||
#include <TrueBroadcastHelper.h>
|
||||
|
||||
namespace nd4j {
|
||||
|
||||
|
@ -2519,8 +2520,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
|
|||
if (isEmpty() || other->isEmpty())
|
||||
return;
|
||||
|
||||
NDArray::prepareSpecialUse({target}, {this, other});
|
||||
|
||||
if (isScalar()) {
|
||||
target->assign(this);
|
||||
target->applyPairwiseTransform(op.p, *other, extraArgs);
|
||||
|
@ -2531,57 +2530,24 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
|
|||
return;
|
||||
}
|
||||
|
||||
const NDArray* min(other);
|
||||
const NDArray* max(this);
|
||||
|
||||
if(this->rankOf() < other->rankOf()) {
|
||||
max = other;
|
||||
min = this;
|
||||
}
|
||||
if(checkTargetShape) {
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
||||
if(!shape::equalsTypesAndShapesSoft(target->getShapeInfo(), newShapeInfo))
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !");
|
||||
}
|
||||
|
||||
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext());
|
||||
|
||||
// check whether max array has to be tiled
|
||||
if(!max->isSameShape(target)) {
|
||||
// evaluate repeating dimensions for tile operation
|
||||
std::vector<Nd4jLong> repeatMax(max->rankOf());
|
||||
for(int i = 1; i <= max->rankOf(); ++i)
|
||||
repeatMax[i - 1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
|
||||
max->tile(repeatMax, *pTarget);
|
||||
}
|
||||
else
|
||||
pTarget->assign(max);
|
||||
|
||||
// check whether min array has to be tiled
|
||||
std::vector<Nd4jLong> repeatMin(min->rankOf());
|
||||
int product = 1;
|
||||
for(int i = min->rankOf(); i >=1 ; --i) {
|
||||
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
|
||||
product *= repeatMin[i-1];
|
||||
if(target->isSameShape(this) || target->isSameShape(other)) {
|
||||
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
|
||||
return;
|
||||
}
|
||||
|
||||
auto pMin = const_cast<NDArray *>(min);
|
||||
if(product != 1 )
|
||||
pMin = new NDArray(min->tile(repeatMin));
|
||||
|
||||
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);
|
||||
|
||||
if(max == this)
|
||||
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
|
||||
else
|
||||
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
|
||||
|
||||
if(pMin != min)
|
||||
delete pMin;
|
||||
if(pTarget != target)
|
||||
delete pTarget;
|
||||
#ifdef __ND4J_EXPERIMENTAL__
|
||||
BUILD_PAIRWISE_SELECTOR(dataType(), other->dataType(), target->dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||
#else
|
||||
BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES);
|
||||
#endif
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -2594,8 +2560,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
if (isEmpty() || other->isEmpty())
|
||||
return;
|
||||
|
||||
NDArray::prepareSpecialUse({target}, {this, other});
|
||||
|
||||
if (isScalar()) {
|
||||
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
||||
temp.assign(this);
|
||||
|
@ -2607,17 +2571,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
return;
|
||||
}
|
||||
|
||||
const NDArray* min(other);
|
||||
const NDArray* max(this);
|
||||
|
||||
if(this->rankOf() < other->rankOf()) {
|
||||
max = other;
|
||||
min = this;
|
||||
}
|
||||
|
||||
if(checkTargetShape) {
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
||||
if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != DataType::BOOL)
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !");
|
||||
|
@ -2625,47 +2581,16 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !");
|
||||
}
|
||||
|
||||
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext());
|
||||
// check whether max array has to be tiled
|
||||
if(!max->isSameShape(target)) {
|
||||
// evaluate repeating dimensions for tile operation
|
||||
std::vector<Nd4jLong> repeatMax(max->rankOf());
|
||||
for(int i = 1; i <= max->rankOf(); ++i)
|
||||
repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
|
||||
max->tile(repeatMax, *pTarget);
|
||||
}
|
||||
else
|
||||
pTarget->assign(max);
|
||||
|
||||
// check whether min array has to be tiled
|
||||
std::vector<Nd4jLong> repeatMin(min->rankOf());
|
||||
int product = 1;
|
||||
for(int i = min->rankOf(); i >=1 ; --i) {
|
||||
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
|
||||
product *= repeatMin[i-1];
|
||||
if(target->isSameShape(this) || target->isSameShape(other)) {
|
||||
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
|
||||
return;
|
||||
}
|
||||
|
||||
auto pMin = const_cast<NDArray *>(min);
|
||||
if(product != 1 )
|
||||
pMin = new NDArray(min->tile(repeatMin));
|
||||
|
||||
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);
|
||||
|
||||
if(max == this)
|
||||
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
|
||||
else
|
||||
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
|
||||
|
||||
if(pMin != min)
|
||||
delete pMin;
|
||||
if(pTarget != target)
|
||||
delete pTarget;
|
||||
BUILD_DOUBLE_SELECTOR(dataType(), target->dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, BOOL_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
|
||||
void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
|
||||
if (isS())
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!");
|
||||
if(target == nullptr || other == nullptr)
|
||||
|
@ -2674,8 +2599,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
if (isEmpty() || other->isEmpty())
|
||||
return;
|
||||
|
||||
NDArray::prepareSpecialUse({target}, {this, other});
|
||||
|
||||
if (isScalar()) {
|
||||
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
||||
temp.assign(this);
|
||||
|
@ -2687,17 +2610,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
return;
|
||||
}
|
||||
|
||||
const NDArray* min(other);
|
||||
const NDArray* max(this);
|
||||
|
||||
if(this->rankOf() < other->rankOf()) {
|
||||
max = other;
|
||||
min = this;
|
||||
}
|
||||
|
||||
if(checkTargetShape) {
|
||||
Nd4jLong* newShapeInfo = nullptr;
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
|
||||
if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType())
|
||||
throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !");
|
||||
|
@ -2705,44 +2620,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
|||
throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
|
||||
}
|
||||
|
||||
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext());
|
||||
// check whether max array has to be tiled
|
||||
if(!max->isSameShape(target)) {
|
||||
// evaluate repeating dimensions for tile operation
|
||||
std::vector<Nd4jLong> repeatMax(max->rankOf());
|
||||
for(int i = 1; i <= max->rankOf(); ++i)
|
||||
repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
|
||||
max->tile(repeatMax, *pTarget);
|
||||
}
|
||||
else
|
||||
pTarget->assign(max);
|
||||
|
||||
// check whether min array has to be tiled
|
||||
std::vector<Nd4jLong> repeatMin(min->rankOf());
|
||||
int product = 1;
|
||||
for(int i = min->rankOf(); i >=1 ; --i) {
|
||||
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
|
||||
product *= repeatMin[i-1];
|
||||
if(target->isSameShape(this) || target->isSameShape(other)) {
|
||||
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
|
||||
return;
|
||||
}
|
||||
|
||||
auto pMin = const_cast<NDArray *>(min);
|
||||
if(product != 1 )
|
||||
pMin = new NDArray(min->tile(repeatMin));
|
||||
|
||||
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);
|
||||
|
||||
if(max == this)
|
||||
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
|
||||
else
|
||||
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
|
||||
|
||||
if(pMin != min)
|
||||
delete pMin;
|
||||
if(pTarget != target)
|
||||
delete pTarget;
|
||||
}
|
||||
|
||||
|
||||
BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, *other, *target), INTEGER_TYPES);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
|
||||
|
@ -2884,7 +2768,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) {
|
||||
void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) {
|
||||
if (!isZ())
|
||||
throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
|
||||
if(isEmpty() || other->isEmpty()) {
|
||||
|
@ -2941,7 +2825,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
|
|||
else
|
||||
NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
|
||||
registerSpecialUse({result}, {this, other});
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) {
|
||||
|
|
|
@ -78,7 +78,6 @@ bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual();
|
|||
bool NDArray::isActualOnDeviceSide() const { return _buffer->isSpecialActual(); }
|
||||
void NDArray::makeBothBuffersActual() const { if(!isActualOnHostSide()) syncToHost(); if(!isActualOnDeviceSide()) syncToDevice(); }
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) {
|
||||
|
|
|
@ -81,7 +81,8 @@ namespace nd4j {
|
|||
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
|
||||
static bool evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr);
|
||||
|
||||
// return sorted vector of dimensions of array with larger dimensions along which two input arrays have same shape
|
||||
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
|
||||
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
|
||||
static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min);
|
||||
|
||||
// evaluate shapeInfo for resulting array of tile operation
|
||||
|
@ -169,6 +170,18 @@ namespace nd4j {
|
|||
* @return
|
||||
*/
|
||||
static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings);
|
||||
|
||||
/*
|
||||
* check whether arr1/arr2 is sub-array of arr2/arr1,
|
||||
* this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1
|
||||
* sameDims is filled (and sorted) with dimensions values that match both in arr1 and arr2 shapes (unities are ignored)
|
||||
* for example:
|
||||
* if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2}
|
||||
* if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims contains {2,4}
|
||||
* if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims contains {2,5}
|
||||
|
||||
static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims);
|
||||
*/
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#ifndef LIBND4J_TRUEBROADCASTHELPER_H
|
||||
#define LIBND4J_TRUEBROADCASTHELPER_H
|
||||
|
||||
#include <NDArray.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
class TrueBroadcastHelper {
|
||||
|
||||
#ifdef __CUDACC__
|
||||
template <typename OpType>
|
||||
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
|
||||
#else
|
||||
template <typename OpType>
|
||||
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
#endif
|
||||
|
||||
public:
|
||||
static void exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
class TrueBroadcastBoolHelper {
|
||||
|
||||
#ifdef __CUDACC__
|
||||
template <typename OpType>
|
||||
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
|
||||
#else
|
||||
template <typename OpType>
|
||||
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
static void exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
class TrueBroadcastIntHelper {
|
||||
|
||||
#ifdef __CUDACC__
|
||||
template <typename OpType>
|
||||
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
|
||||
#else
|
||||
template <typename OpType>
|
||||
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
#endif
|
||||
|
||||
public:
|
||||
|
||||
static void exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
|
||||
};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif //LIBND4J_BIDIAGONALUP_H
|
|
@ -0,0 +1,218 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <TrueBroadcastHelper.h>
|
||||
|
||||
using namespace simdOps;
|
||||
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if(iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if(iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if(iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||
}
|
||||
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,309 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
// #include <exceptions/cuda_exception.h>
|
||||
#include <TrueBroadcastHelper.h>
|
||||
#include <PointersManager.h>
|
||||
#include <execution/LaunchContext.h>
|
||||
#include <specials.h>
|
||||
#include <logger.h>
|
||||
// #include <cuda_runtime.h>
|
||||
// #include <cuda.h>
|
||||
|
||||
using namespace simdOps;
|
||||
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z, typename OpType>
|
||||
__global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const Y*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ int xRank, yRank, zRank;
|
||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
|
||||
xRank = shape::rank(xShapeInfo);
|
||||
yRank = shape::rank(yShapeInfo);
|
||||
zRank = shape::rank(zShapeInfo);
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
|
||||
auto yCoords = xCoords + xRank;
|
||||
auto zCoords = yCoords + yRank;
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords);
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0)
|
||||
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
else
|
||||
xCoords[ix--] = 0;
|
||||
|
||||
if(iy >= 0)
|
||||
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
else
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X, typename Y, typename Z>
|
||||
template <typename OpType>
|
||||
void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
|
||||
|
||||
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
DISPATCH_BY_OPNUM_TTT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_OPS));
|
||||
|
||||
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z, typename OpType>
|
||||
__global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const X*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ int xRank, yRank, zRank;
|
||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
|
||||
xRank = shape::rank(xShapeInfo);
|
||||
yRank = shape::rank(yShapeInfo);
|
||||
zRank = shape::rank(zShapeInfo);
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
|
||||
auto yCoords = xCoords + xRank;
|
||||
auto zCoords = yCoords + yRank;
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords);
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0)
|
||||
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
else
|
||||
xCoords[ix--] = 0;
|
||||
|
||||
if(iy >= 0)
|
||||
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
else
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X, typename Z>
|
||||
template <typename OpType>
|
||||
void TrueBroadcastBoolHelper<X,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
trueBroadcastBoolCuda<X,Z,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
|
||||
|
||||
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
DISPATCH_BY_OPNUM_TT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_BOOL_OPS));
|
||||
|
||||
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
__global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const X*>(vy);
|
||||
auto z = reinterpret_cast<X*>(vz);
|
||||
|
||||
__shared__ int xRank, yRank, zRank;
|
||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
|
||||
xRank = shape::rank(xShapeInfo);
|
||||
yRank = shape::rank(yShapeInfo);
|
||||
zRank = shape::rank(zShapeInfo);
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
|
||||
auto yCoords = xCoords + xRank;
|
||||
auto zCoords = yCoords + yRank;
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords);
|
||||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0)
|
||||
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
else
|
||||
xCoords[ix--] = 0;
|
||||
|
||||
if(iy >= 0)
|
||||
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
else
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X>
|
||||
template <typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
|
||||
|
||||
trueBroadcastIntCuda<X,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
|
||||
|
||||
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
DISPATCH_BY_OPNUM_T(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_INT_OPS));
|
||||
|
||||
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
|
||||
|
||||
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
|
||||
}
|
||||
}
|
|
@ -515,21 +515,30 @@ bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>&
|
|||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// return sorted vector of dimensions of array with larger dimensions number along which two input arrays have same shape
|
||||
// the array with larger dimensions number has to be passed as first argument
|
||||
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& max, const NDArray& min) {
|
||||
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
|
||||
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
|
||||
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
|
||||
|
||||
std::vector<int> result;
|
||||
auto maxShapeInfo = max.getShapeInfo();
|
||||
auto minShapeInfo = min.getShapeInfo();
|
||||
int maxRank = maxShapeInfo[0];
|
||||
int minRank = minShapeInfo[0];
|
||||
const NDArray *min, *max;
|
||||
|
||||
for (int i = 1; i <= minRank; ++i)
|
||||
if (minShapeInfo[i] == maxShapeInfo[maxRank - minRank + i])
|
||||
result.emplace_back(maxRank - minRank + i - 1);
|
||||
if(arr1.rankOf() >= arr2.rankOf()) {
|
||||
max = &arr1;
|
||||
min = &arr2;
|
||||
}
|
||||
else {
|
||||
max = &arr2;
|
||||
min = &arr1;
|
||||
}
|
||||
|
||||
return result;
|
||||
const int rankDiff = max->rankOf() - min->rankOf();
|
||||
|
||||
std::vector<int> dims;
|
||||
|
||||
for (int i = 0; i < min->rankOf(); ++i)
|
||||
if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
|
||||
dims.emplace_back(rankDiff + i);
|
||||
|
||||
return dims;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -997,14 +1006,56 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
|
|||
}
|
||||
|
||||
|
||||
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
|
||||
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
|
||||
// we store +1 offset
|
||||
auto base = numStrings + 1;
|
||||
|
||||
// since we return number of bytes...
|
||||
return base * sizeof(Nd4jLong);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/*
|
||||
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
|
||||
|
||||
if(!sameDims.empty())
|
||||
sameDims.clear();
|
||||
|
||||
const NDArray* max = &arr1;
|
||||
const NDArray* min = &arr2;
|
||||
|
||||
if(arr1.lengthOf() < arr2.lengthOf()) {
|
||||
max = &arr2;
|
||||
min = &arr1;
|
||||
}
|
||||
|
||||
int numUnitiesInMin = 0;
|
||||
|
||||
for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= -min->rankOf(); ) {
|
||||
|
||||
if(max->sizeAt(iMax) == 1) { // ignore unities in shape
|
||||
--iMax;
|
||||
continue;
|
||||
}
|
||||
|
||||
if(min->sizeAt(iMin) == 1) { // ignore unities in shape
|
||||
++numUnitiesInMin;
|
||||
--iMin;
|
||||
continue;
|
||||
}
|
||||
|
||||
if(max->sizeAt(iMax) == min->sizeAt(iMin)) {
|
||||
sameDims.insert(sameDims.begin(), iMax + max->rankOf());
|
||||
--iMin;
|
||||
}
|
||||
|
||||
--iMax;
|
||||
}
|
||||
|
||||
return sameDims.size() + numUnitiesInMin == min->rankOf();
|
||||
}
|
||||
*/
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue