[WIP] Broadcast changes (#8257)

* - provide correct call NDArray::applyBroadcast inside of NDArray::applyTrueBroadcast

Signed-off-by: Yurii <yurii@skymind.io>

* - provide new trueBroadcast helper

Signed-off-by: Yurii <yurii@skymind.io>

* example for yurii

Signed-off-by: raver119 <raver119@gmail.com>

* - provide new trueBroadcast helper for cpu

Signed-off-by: Yurii <yurii@skymind.io>

* - start working on new trueBroadcat helper for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - further work on trueBroadcast for cuda

Signed-off-by: Yurii <yurii@skymind.io>

* - fix bugs in cuda helper trueBroadcast

Signed-off-by: Yurii <yurii@skymind.io>
master
raver119 2019-10-01 09:10:19 +03:00 committed by GitHub
parent 5959ff4795
commit 44a8d19ac6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 806 additions and 248 deletions

View File

@ -25,6 +25,7 @@
#include <ConstantTadHelper.h> #include <ConstantTadHelper.h>
#include <BroadcastPairwiseConverter.h> #include <BroadcastPairwiseConverter.h>
#include <helpers/PointersManager.h> #include <helpers/PointersManager.h>
#include <TrueBroadcastHelper.h>
namespace nd4j { namespace nd4j {
@ -2519,8 +2520,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
if (isEmpty() || other->isEmpty()) if (isEmpty() || other->isEmpty())
return; return;
NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) { if (isScalar()) {
target->assign(this); target->assign(this);
target->applyPairwiseTransform(op.p, *other, extraArgs); target->applyPairwiseTransform(op.p, *other, extraArgs);
@ -2531,57 +2530,24 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
return; return;
} }
const NDArray* min(other);
const NDArray* max(this);
if(this->rankOf() < other->rankOf()) {
max = other;
min = this;
}
if(checkTargetShape) { if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
if(!shape::equalsTypesAndShapesSoft(target->getShapeInfo(), newShapeInfo)) if(!shape::equalsTypesAndShapesSoft(target->getShapeInfo(), newShapeInfo))
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !");
} }
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext()); if(target->isSameShape(this) || target->isSameShape(other)) {
const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
// check whether max array has to be tiled return;
if(!max->isSameShape(target)) {
// evaluate repeating dimensions for tile operation
std::vector<Nd4jLong> repeatMax(max->rankOf());
for(int i = 1; i <= max->rankOf(); ++i)
repeatMax[i - 1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
max->tile(repeatMax, *pTarget);
}
else
pTarget->assign(max);
// check whether min array has to be tiled
std::vector<Nd4jLong> repeatMin(min->rankOf());
int product = 1;
for(int i = min->rankOf(); i >=1 ; --i) {
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
product *= repeatMin[i-1];
} }
auto pMin = const_cast<NDArray *>(min); #ifdef __ND4J_EXPERIMENTAL__
if(product != 1 ) BUILD_PAIRWISE_SELECTOR(dataType(), other->dataType(), target->dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, LIBND4J_TYPES);
pMin = new NDArray(min->tile(repeatMin)); #else
BUILD_SINGLE_SELECTOR_THRICE(dataType(), helpers::TrueBroadcastHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES);
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin); #endif
if(max == this)
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
else
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
if(pMin != min)
delete pMin;
if(pTarget != target)
delete pTarget;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -2594,8 +2560,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
if (isEmpty() || other->isEmpty()) if (isEmpty() || other->isEmpty())
return; return;
NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) { if (isScalar()) {
NDArray temp(target->_shapeInfo, dataType(), false, getContext()); NDArray temp(target->_shapeInfo, dataType(), false, getContext());
temp.assign(this); temp.assign(this);
@ -2607,17 +2571,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
return; return;
} }
const NDArray* min(other);
const NDArray* max(this);
if(this->rankOf() < other->rankOf()) {
max = other;
min = this;
}
if(checkTargetShape) { if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != DataType::BOOL) if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != DataType::BOOL)
throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !");
@ -2625,47 +2581,16 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !");
} }
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext()); if(target->isSameShape(this) || target->isSameShape(other)) {
// check whether max array has to be tiled const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
if(!max->isSameShape(target)) { return;
// evaluate repeating dimensions for tile operation
std::vector<Nd4jLong> repeatMax(max->rankOf());
for(int i = 1; i <= max->rankOf(); ++i)
repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
max->tile(repeatMax, *pTarget);
}
else
pTarget->assign(max);
// check whether min array has to be tiled
std::vector<Nd4jLong> repeatMin(min->rankOf());
int product = 1;
for(int i = min->rankOf(); i >=1 ; --i) {
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
product *= repeatMin[i-1];
} }
auto pMin = const_cast<NDArray *>(min); BUILD_DOUBLE_SELECTOR(dataType(), target->dataType(), helpers::TrueBroadcastBoolHelper, ::exec(op.b, *this, *other, *target), LIBND4J_TYPES, BOOL_TYPES);
if(product != 1 )
pMin = new NDArray(min->tile(repeatMin));
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);
if(max == this)
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
else
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
if(pMin != min)
delete pMin;
if(pTarget != target)
delete pTarget;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
if (isS()) if (isS())
throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!");
if(target == nullptr || other == nullptr) if(target == nullptr || other == nullptr)
@ -2674,8 +2599,6 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
if (isEmpty() || other->isEmpty()) if (isEmpty() || other->isEmpty())
return; return;
NDArray::prepareSpecialUse({target}, {this, other});
if (isScalar()) { if (isScalar()) {
NDArray temp(target->_shapeInfo, dataType(), false, getContext()); NDArray temp(target->_shapeInfo, dataType(), false, getContext());
temp.assign(this); temp.assign(this);
@ -2687,17 +2610,9 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
return; return;
} }
const NDArray* min(other);
const NDArray* max(this);
if(this->rankOf() < other->rankOf()) {
max = other;
min = this;
}
if(checkTargetShape) { if(checkTargetShape) {
Nd4jLong* newShapeInfo = nullptr; Nd4jLong* newShapeInfo = nullptr;
if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() if(!ShapeUtils::evalBroadcastShapeInfo(*this, *other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)()
throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType()) if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType())
throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !");
@ -2705,44 +2620,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
} }
NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext()); if(target->isSameShape(this) || target->isSameShape(other)) {
// check whether max array has to be tiled const_cast<NDArray*>(this)->applyBroadcast(op.b, ShapeUtils::getDimsWithSameShape(*this, *other), other, target, extraArgs);
if(!max->isSameShape(target)) { return;
// evaluate repeating dimensions for tile operation
std::vector<Nd4jLong> repeatMax(max->rankOf());
for(int i = 1; i <= max->rankOf(); ++i)
repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
max->tile(repeatMax, *pTarget);
}
else
pTarget->assign(max);
// check whether min array has to be tiled
std::vector<Nd4jLong> repeatMin(min->rankOf());
int product = 1;
for(int i = min->rankOf(); i >=1 ; --i) {
repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
product *= repeatMin[i-1];
} }
auto pMin = const_cast<NDArray *>(min); BUILD_SINGLE_SELECTOR(dataType(), helpers::TrueBroadcastIntHelper, ::exec(op.b, *this, *other, *target), INTEGER_TYPES);
if(product != 1 ) }
pMin = new NDArray(min->tile(repeatMin));
std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);
if(max == this)
pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
else
pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);
if(pMin != min)
delete pMin;
if(pTarget != target)
delete pTarget;
}
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
@ -2884,7 +2768,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) {
if (!isZ()) if (!isZ())
throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
if(isEmpty() || other->isEmpty()) { if(isEmpty() || other->isEmpty()) {
@ -2941,7 +2825,7 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
else else
NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
registerSpecialUse({result}, {this, other}); registerSpecialUse({result}, {this, other});
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) { void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) {

View File

@ -78,7 +78,6 @@ bool NDArray::isActualOnHostSide() const { return _buffer->isPrimaryActual();
bool NDArray::isActualOnDeviceSide() const { return _buffer->isSpecialActual(); } bool NDArray::isActualOnDeviceSide() const { return _buffer->isSpecialActual(); }
void NDArray::makeBothBuffersActual() const { if(!isActualOnHostSide()) syncToHost(); if(!isActualOnDeviceSide()) syncToDevice(); } void NDArray::makeBothBuffersActual() const { if(!isActualOnHostSide()) syncToHost(); if(!isActualOnDeviceSide()) syncToDevice(); }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
__global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) { __global__ static void fillAsTriangularCuda(const void* vx, const Nd4jLong* xShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const T val, const int lower, const int upper) {

View File

@ -81,7 +81,8 @@ namespace nd4j {
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
static bool evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr); static bool evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace = nullptr);
// return sorted vector of dimensions of array with larger dimensions along which two input arrays have same shape // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min); static std::vector<int> getDimsWithSameShape(const NDArray& max, const NDArray& min);
// evaluate shapeInfo for resulting array of tile operation // evaluate shapeInfo for resulting array of tile operation
@ -169,6 +170,18 @@ namespace nd4j {
* @return * @return
*/ */
static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings); static Nd4jLong stringBufferHeaderRequirements(Nd4jLong numStrings);
/*
* check whether arr1/arr2 is sub-array of arr2/arr1,
* this method do not evaluate what array is sub-array, it returns true if arr1 is sub-array of arr2 or arr2 is sub-array of arr1
* sameDims is filled (and sorted) with dimensions values that match both in arr1 and arr2 shapes (unities are ignored)
* for example:
* if arr1{2,3} and arr2{2,4,3,7} then return true and sameDims contains {0,2}
* if arr1{1,1,3,1,3,1,1} and arr2{1,2,3,1,3} then return true and sameDims contains {2,4}
* if arr1{2,1,4,1,7,5} and arr2{1,1,4,5} then return true and sameDims contains {2,5}
static bool isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims);
*/
}; };

View File

@ -0,0 +1,84 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_TRUEBROADCASTHELPER_H
#define LIBND4J_TRUEBROADCASTHELPER_H
#include <NDArray.h>
namespace nd4j {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
class TrueBroadcastHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
template <typename X, typename Y>
class TrueBroadcastBoolHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
////////////////////////////////////////////////////////////////////////
template <typename X>
class TrueBroadcastIntHelper {
#ifdef __CUDACC__
template <typename OpType>
static __host__ void execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo);
#else
template <typename OpType>
static void exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
#endif
public:
static void exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr);
};
}
}
#endif //LIBND4J_BIDIAGONALUP_H

View File

@ -0,0 +1,218 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <TrueBroadcastHelper.h>
using namespace simdOps;
namespace nd4j {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
X* z = reinterpret_cast<X*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
template <typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
}
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
}
}

View File

@ -0,0 +1,309 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
// #include <exceptions/cuda_exception.h>
#include <TrueBroadcastHelper.h>
#include <PointersManager.h>
#include <execution/LaunchContext.h>
#include <specials.h>
#include <logger.h>
// #include <cuda_runtime.h>
// #include <cuda.h>
using namespace simdOps;
namespace nd4j {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z, typename OpType>
__global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Y*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
template <typename OpType>
void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y, typename Z>
void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_TTT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z, typename OpType>
__global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpType>
void TrueBroadcastBoolHelper<X,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastBoolCuda<X,Z,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_TT(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_BOOL_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
__global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const X*>(vy);
auto z = reinterpret_cast<X*>(vz);
__shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo);
zLen = shape::length(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
auto yCoords = xCoords + xRank;
auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0)
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz];
else
xCoords[ix--] = 0;
if(iy >= 0)
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz];
else
yCoords[iy--] = 0;
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpType>
void TrueBroadcastIntHelper<X>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastIntCuda<X,OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
}
//////////////////////////////////////////////////////////////////////////
template<typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
NDArray::prepareSpecialUse({&zArr}, {&xArr, &yArr});
DISPATCH_BY_OPNUM_T(execLauncher, PARAMS(launchDims, xArr.getContext()->getCudaStream(), xArr.getSpecialBuffer(), xArr.getSpecialShapeInfo(), yArr.getSpecialBuffer(), yArr.getSpecialShapeInfo(), zArr.specialBuffer(), zArr.specialShapeInfo()), OPS_A(BROADCAST_INT_OPS));
NDArray::registerSpecialUse({&zArr}, {&xArr, &yArr});
manager.synchronize();
}
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
}
}

View File

@ -515,21 +515,30 @@ bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>&
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// return sorted vector of dimensions of array with larger dimensions number along which two input arrays have same shape // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// the array with larger dimensions number has to be passed as first argument // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& max, const NDArray& min) { std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {
std::vector<int> result; const NDArray *min, *max;
auto maxShapeInfo = max.getShapeInfo();
auto minShapeInfo = min.getShapeInfo();
int maxRank = maxShapeInfo[0];
int minRank = minShapeInfo[0];
for (int i = 1; i <= minRank; ++i) if(arr1.rankOf() >= arr2.rankOf()) {
if (minShapeInfo[i] == maxShapeInfo[maxRank - minRank + i]) max = &arr1;
result.emplace_back(maxRank - minRank + i - 1); min = &arr2;
}
else {
max = &arr2;
min = &arr1;
}
return result; const int rankDiff = max->rankOf() - min->rankOf();
std::vector<int> dims;
for (int i = 0; i < min->rankOf(); ++i)
if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
dims.emplace_back(rankDiff + i);
return dims;
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -997,14 +1006,56 @@ std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const
} }
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) { Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
// we store +1 offset // we store +1 offset
auto base = numStrings + 1; auto base = numStrings + 1;
// since we return number of bytes... // since we return number of bytes...
return base * sizeof(Nd4jLong); return base * sizeof(Nd4jLong);
}
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {
if(!sameDims.empty())
sameDims.clear();
const NDArray* max = &arr1;
const NDArray* min = &arr2;
if(arr1.lengthOf() < arr2.lengthOf()) {
max = &arr2;
min = &arr1;
} }
int numUnitiesInMin = 0;
for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= -min->rankOf(); ) {
if(max->sizeAt(iMax) == 1) { // ignore unities in shape
--iMax;
continue;
}
if(min->sizeAt(iMin) == 1) { // ignore unities in shape
++numUnitiesInMin;
--iMin;
continue;
}
if(max->sizeAt(iMax) == min->sizeAt(iMin)) {
sameDims.insert(sameDims.begin(), iMax + max->rankOf());
--iMin;
}
--iMax;
}
return sameDims.size() + numUnitiesInMin == min->rankOf();
}
*/
} }