[WIP] Int broadcastables (#195)
* Removed invalid resource and fixed tests  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* legacy scalar/pairwise/broadcast int ops  Signed-off-by: raver119 <raver119@gmail.com>
* NDArray int broadcastables  Signed-off-by: raver119 <raver119@gmail.com>
* few more bitwise tests  Signed-off-by: raver119 <raver119@gmail.com>
* java side update  Signed-off-by: raver119 <raver119@gmail.com>
* Argument type changed for shift ops  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* legacy scalar/pairwise/broadcast int ops  Signed-off-by: raver119 <raver119@gmail.com>
* NDArray int broadcastables  Signed-off-by: raver119 <raver119@gmail.com>
* few more bitwise tests  Signed-off-by: raver119 <raver119@gmail.com>
* java side update  Signed-off-by: raver119 <raver119@gmail.com>
* Argument type changed for shift ops  Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

parent 130c9aa536
commit 1003428a18
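In short, this commit threads a third family of legacy operations, integer-only bitwise and shift ops, through the same scalar / pairwise / broadcast pipelines that the float and bool ops already use, on both the CPU and CUDA backends, together with the matching NDArray entry points and Java-side bindings. As a rough usage sketch of the API added below (the NDArrayFactory overloads and the scalar::IntOps::ShiftLeft enumerator are assumptions inferred from this diff, not confirmed by it):

// Hedged sketch: exercises the new integer pairwise and scalar paths declared
// later in this diff. The factory overloads follow common libnd4j usage and are
// assumptions; scalar::IntOps::ShiftLeft mirrors the pairwise name named below.
#include <NDArrayFactory.h>

void bitwiseDemo() {
    auto x = nd4j::NDArrayFactory::create<int>('c', {4}, {1, 2, 4, 8});
    auto y = nd4j::NDArrayFactory::create<int>('c', {4}, {3, 3, 3, 3});
    auto z = nd4j::NDArrayFactory::create<int>('c', {4});

    // element-wise bitwise AND via the new IntOps pairwise path
    x.applyPairwiseTransform(nd4j::pairwise::IntOps::IntAnd, &y, &z, nullptr);

    // left-shift every element by a scalar via the new IntOps scalar path
    x.applyScalar<int>(nd4j::scalar::IntOps::ShiftLeft, 1, &z, nullptr);
}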
@@ -103,7 +103,7 @@ public class FastTextTest extends BaseDL4JTest {
     }

     @Test
-    public void testPredict() throws IOException {
+    public void testPredict() {
         String text = "I like soccer";

         FastText fastText = new FastText(supModelFile);
@@ -118,7 +118,7 @@ public class FastTextTest extends BaseDL4JTest {
     }

     @Test
-    public void testPredictProbability() throws IOException {
+    public void testPredictProbability() {
         String text = "I like soccer";

         FastText fastText = new FastText(supModelFile);
@@ -34,6 +34,7 @@
#include <op_enums.h>
#include <ops/BroadcastOpsTuple.h>
#include <ops/BroadcastBoolOpsTuple.h>
#include <ops/BroadcastIntOpsTuple.h>
#include <array/ExtraArguments.h>
#include <Status.h>
#include <ShapeDescriptor.h>

@@ -659,6 +660,8 @@ namespace nd4j {

        void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const;

        void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const;

        /**
         * apply operation which requires broadcasting, broadcast a smaller array (tad) along a bigger one (this)
         * tad - array to broadcast
@@ -672,6 +675,9 @@ namespace nd4j {

        void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int> &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr);

        void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int> &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr);

        /**
         * apply operation which requires broadcasting, broadcast one tensor along another; this method also checks whether broadcasting is possible
         * other - input array
@@ -692,6 +698,9 @@ namespace nd4j {

        void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;

        void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;

        /**
         * apply a scalar operation to an array
         * scalar - input scalar
@@ -704,6 +713,9 @@ namespace nd4j {
        template <typename T>
        void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;

        template <typename T>
        void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;

        /**
         * apply a scalar operation to an array
         * scalar - input array which is a simple scalar
@@ -714,6 +726,7 @@ namespace nd4j {

        void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;

        void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;

#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
        template <typename Lambda>
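The BroadcastIntOpsTuple parameter taken by applyTrueBroadcast above bundles the scalar, pairwise and broadcast flavor of one logical operation, so the implementation can pick the cheapest kernel at runtime (see the op.s / op.p / op.b dispatch in the implementation further down). A self-contained sketch of the assumed tuple layout; the real class lives in ops/BroadcastIntOpsTuple.h, which this diff does not show:

// Assumed shape of nd4j::BroadcastIntOpsTuple, inferred from the op.s/op.p/op.b
// accesses in NDArray::applyTrueBroadcast later in this diff. Placeholder enums
// stand in for nd4j's real scalar/pairwise/broadcast IntOps declarations.
namespace sketch {
    enum class ScalarIntOps    { IntAnd, IntOr, IntXor };
    enum class PairwiseIntOps  { IntAnd, IntOr, IntXor };
    enum class BroadcastIntOps { IntAnd, IntOr, IntXor };

    struct BroadcastIntOpsTuple {
        ScalarIntOps    s;   // picked when one operand is a scalar
        PairwiseIntOps  p;   // picked when shapes already match element-wise
        BroadcastIntOps b;   // picked for the general TAD broadcast case
    };
}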

@@ -2663,6 +2663,88 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
        delete pTarget;
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
    if (isS())
        throw std::runtime_error("NDArray::applyTrueBroadcast int: you can't use this method on String array!");
    if (target == nullptr || other == nullptr)
        throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !");

    if (isEmpty() || other->isEmpty())
        return;

    NDArray::prepareSpecialUse({target}, {this, other});

    if (isScalar()) {
        NDArray temp(target->_shapeInfo, dataType(), false, getContext());
        temp.assign(this);
        temp.applyPairwiseTransform(op.p, other, target, extraArgs);
        return;
    }
    if (other->isScalar()) {
        this->applyScalarArr(op.s, other, target, extraArgs);
        return;
    }

    const NDArray* min(other);
    const NDArray* max(this);

    if (this->rankOf() < other->rankOf()) {
        max = other;
        min = this;
    }

    if (checkTargetShape) {
        Nd4jLong* newShapeInfo = nullptr;
        if (!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace()))   // the rank of target array must be equal to max->rankOf()
            throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
        if (!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType())
            throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !");
        if (dataType() != other->dataType())
            throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
    }

    NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext());
    // check whether max array has to be tiled
    if (!max->isSameShape(target)) {
        // evaluate repeating dimensions for tile operation
        std::vector<Nd4jLong> repeatMax(max->rankOf());
        for (int i = 1; i <= max->rankOf(); ++i)
            repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
        max->tile(repeatMax, *pTarget);
    }
    else
        pTarget->assign(max);

    // check whether min array has to be tiled
    std::vector<Nd4jLong> repeatMin(min->rankOf());
    int product = 1;
    for (int i = min->rankOf(); i >= 1; --i) {
        repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
        product *= repeatMin[i-1];
    }

    auto pMin = const_cast<NDArray *>(min);
    if (product != 1)
        pMin = new NDArray(min->tile(repeatMin));

    std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);

    if (max == this)
        pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
    else
        pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);

    if (pMin != min)
        delete pMin;
    if (pTarget != target)
        delete pTarget;
}
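The two repeat loops above are plain element-wise shape divisions: every axis of the smaller operand must either match the target extent or divide it, and the quotient is the tile count. A standalone illustration with made-up shapes:

// Hypothetical illustration of the repeat-count math used above:
// target shape [2,3,4], max shape [1,3,4] -> tile max 2x along axis 0.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<int64_t> targetShape{2, 3, 4};
    std::vector<int64_t> maxShape{1, 3, 4};
    std::vector<int64_t> repeatMax(maxShape.size());
    for (size_t i = 0; i < maxShape.size(); ++i)
        repeatMax[i] = targetShape[i] / maxShape[i];   // {2, 1, 1}
    for (auto r : repeatMax)
        printf("%lld ", (long long) r);                // prints: 2 1 1
    printf("\n");
    return 0;
}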

//////////////////////////////////////////////////////////////////////////
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
    if (isEmpty() || other.isEmpty()) {

@@ -2801,6 +2883,67 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
    registerSpecialUse({result}, {this, other});
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) {
    if (!isZ())
        throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
    if (isEmpty() || other->isEmpty()) {
        if (!target->isEmpty())
            throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !");
        return;
    }

    if (dimensions.empty())
        return;

    auto result = target == nullptr ? this : target;

    if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
        NDArray::prepareSpecialUse({result}, {this, other});
        NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
        NDArray::registerSpecialUse({result}, {this, other});
        return;
    }

    NDArray *min(nullptr), *max(nullptr);
    if ((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) {
        max = this;
        min = const_cast<NDArray*>(other);
    }
    else {
        max = const_cast<NDArray*>(other);
        min = this;
    }

    if (result->dataType() != dataType())
        throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!");
    if (!result->isSameShape(max))
        throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !");
    if (_dataType != other->_dataType)
        throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");

    std::vector<int> copy(dimensions);

    if (dimensions.size() > 1)
        std::sort(copy.begin(), copy.end());

    Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
    if (tadLength != min->lengthOf())
        throw std::runtime_error("Tad length mismatch");

    auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
    auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy);

    // TODO: eventually we want separate tads here
    NDArray::prepareSpecialUse({result}, {this, other});
    if (max == this)
        NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int) copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
    else
        NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int) copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
    registerSpecialUse({result}, {this, other});
}
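To make the TAD branch concrete: with rows as the TADs, broadcasting along dimension 1 applies the vector to every row. A hedged sketch (the factory overloads, as in the earlier example, are assumptions; the op enum name comes from the converter later in this diff):

// Hedged usage sketch for the TAD branch implemented above: AND a row vector
// into every row of an int matrix. NDArrayFactory overloads are assumptions.
#include <NDArrayFactory.h>

void broadcastAndDemo() {
    auto m    = nd4j::NDArrayFactory::create<int>('c', {2, 3}, {7, 7, 7, 5, 5, 5});
    auto mask = nd4j::NDArrayFactory::create<int>('c', {3}, {1, 2, 4});
    auto out  = nd4j::NDArrayFactory::create<int>('c', {2, 3});

    // each row of m is a TAD of length 3, matching mask's length
    m.applyBroadcast(nd4j::broadcast::IntOps::IntAnd, {1}, &mask, &out, nullptr);
    // expected out: {1, 2, 4,  1, 0, 4}
}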

//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) {
    std::vector<int> vec(dimensions);

@@ -3043,6 +3186,22 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
    NDArray::registerSpecialUse({target}, {this, other});
}

////////////////////////////////////////////////////////////////////////
void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const {
    if (isS())
        throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!");
    if (other->lengthOf() != target->lengthOf())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched");
    if (!target->isZ())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have integer type");
    if (dataType() != other->dataType())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !");

    NDArray::prepareSpecialUse({target}, {this, other});
    NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
    NDArray::registerSpecialUse({target}, {this, other});
}

//////////////////////////////////////////////////////////////////////////
void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) {
    applyPairwiseTransform(op, &other, this, extraParams);

@@ -3585,6 +3744,45 @@ template void NDArray::applyScalar<int8_t>(nd4j::scalar::BoolOps op, const int8_
template void NDArray::applyScalar<uint8_t>(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bool>(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;

//////////////////////////////////////////////////////////////////////////
void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const {
    if (isS())
        throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!");

    if (target == nullptr || target->dataType() != this->dataType())
        throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has wrong type!");
    if (dataType() != scalar->dataType()) {
        nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType());
        throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!");
    }

    NDArray::prepareSpecialUse({target}, {this, scalar});
    NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
    NDArray::registerSpecialUse({target}, {this, scalar});
}

////////////////////////////////////////////////////////////////////////
template <typename T>
void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const {

    NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext());
    applyScalarArr(op, &scalarArr, target, extraParams);
}

template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar<NDArray*> method: do not use me!"); }
template void NDArray::applyScalar<double>(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<float>(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<float16>(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bfloat16>(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<Nd4jLong>(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int>(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int16_t>(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int8_t>(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<uint8_t>(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bool>(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
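The explicit instantiations above pin down exactly which scalar T types the template is compiled for, mirroring the BoolOps list. A hedged sketch of the scalar-array path (the ShiftLeft enumerator on scalar::IntOps is assumed by analogy with the pairwise enum named in the converter below):

// Hedged sketch of applyScalarArr above: shift every element of x left by the
// amount held in a zero-rank scalar NDArray of the same dtype. The factory
// overloads and the ShiftLeft enumerator are assumptions, not part of this diff.
#include <NDArrayFactory.h>

void scalarShiftDemo() {
    auto x     = nd4j::NDArrayFactory::create<int>('c', {3}, {1, 2, 3});
    auto shift = nd4j::NDArrayFactory::create<int>(1);   // scalar array
    auto out   = nd4j::NDArrayFactory::create<int>('c', {3});

    x.applyScalarArr(nd4j::scalar::IntOps::ShiftLeft, &shift, &out, nullptr);
    // expected out: {2, 4, 6}
}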

////////////////////////////////////////////////////////////////////////
void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector<int>& dimensions, const ExtraArguments *extraParams) const {
    if (isS())

@@ -189,6 +189,16 @@ static void execScalarBool(nd4j::LaunchContext *lc,
                              void *dScalar, Nd4jLong *dScalarShapeInfo,
                              void *extraParams, bool allowParallelism = true);

    static void execScalarInt(nd4j::LaunchContext *lc,
                              int opNum,
                              void *hX, Nd4jLong *hXShapeInfo,
                              void *dX, Nd4jLong *dXShapeInfo,
                              void *hZ, Nd4jLong *hZShapeInfo,
                              void *dZ, Nd4jLong *dZShapeInfo,
                              void *hScalar, Nd4jLong *hScalarShapeInfo,
                              void *dScalar, Nd4jLong *dScalarShapeInfo,
                              void *extraParams, bool allowParallelism = true);

    static void execScalar(nd4j::LaunchContext *lc,
                           int opNum,
                           void *hX, Nd4jLong *hXShapeInfo,

@@ -215,6 +225,20 @@ static void execScalarBool(nd4j::LaunchContext *lc,
                              Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                              Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

    static void execScalarInt(nd4j::LaunchContext *lc,
                              int opNum,
                              void *hX, Nd4jLong *hXShapeInfo,
                              void *dX, Nd4jLong *dXShapeInfo,
                              void *extraParams,
                              void *hZ, Nd4jLong *hZShapeInfo,
                              void *dZ, Nd4jLong *dZShapeInfo,
                              void *hScalars, Nd4jLong *hScalarShapeInfo,
                              void *dScalars, Nd4jLong *dScalarShapeInfo,
                              int *dimension, int dimensionLength,
                              Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                              Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
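Note there are two execScalarInt overloads, matching the execScalarBool pair above: the first applies one scalar to the whole array, the second applies one scalar per TAD along the given dimensions. A toy model of that difference in plain C++ (illustration only, not libnd4j API):

// Toy illustration of the two execScalarInt flavors declared above:
// whole-array scalar vs. one scalar per TAD slice. No libnd4j types.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> x{1, 2, 3, 4, 5, 6};      // viewed as a 2x3 array
    // flavor 1: one scalar for the whole array
    for (auto &v : x) v <<= 1;                 // now {2, 4, 6, 8, 10, 12}
    // flavor 2: one scalar per row (per TAD along dimension 1)
    std::vector<int> scalars{0, 2};            // row 0 -> <<0, row 1 -> <<2
    for (int row = 0; row < 2; ++row)
        for (int col = 0; col < 3; ++col)
            x[row * 3 + col] <<= scalars[row];
    for (auto v : x) printf("%d ", v);         // prints: 2 4 6 32 40 48
    printf("\n");
    return 0;
}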

    /**
     *
     * @param opNum

@@ -275,6 +299,30 @@ static void execScalarBool(nd4j::LaunchContext *lc,
                                 int *dimension, int dimensionLength,
                                 Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                 Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    static void execBroadcastInt(nd4j::LaunchContext *lc,
                                 int opNum,
                                 void *hX, Nd4jLong *hXShapeInfo,
                                 void *dX, Nd4jLong *dXShapeInfo,
                                 void *hY, Nd4jLong *hYShapeInfo,
                                 void *dY, Nd4jLong *dYShapeInfo,
                                 void *hZ, Nd4jLong *hZShapeInfo,
                                 void *dZ, Nd4jLong *dZShapeInfo,
                                 int *dimension, int dimensionLength,
                                 Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                 Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    static void execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                 int opNum,
                                 void *x, Nd4jLong *xShapeInfo,
                                 void *dX, Nd4jLong *dXShapeInfo,
                                 void *y, Nd4jLong *yShapeInfo,
                                 void *dY, Nd4jLong *dYShapeInfo,
                                 void *result, Nd4jLong *resultShapeInfo,
                                 void *dZ, Nd4jLong *dZShapeInfo,
                                 int *dimension, int dimensionLength,
                                 Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                 Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    /**
     *

@@ -308,6 +356,16 @@ static void execScalarBool(nd4j::LaunchContext *lc,
                           void *dZ, Nd4jLong *dZShapeInfo,
                           void *extraParams);

    static void execPairwiseIntTransform(nd4j::LaunchContext *lc,
                           int opNum,
                           void *hX, Nd4jLong *hXShapeInfo,
                           void *dX, Nd4jLong *dXShapeInfo,
                           void *hY, Nd4jLong *hYShapeInfo,
                           void *dY, Nd4jLong *dYShapeInfo,
                           void *hZ, Nd4jLong *hZShapeInfo,
                           void *dZ, Nd4jLong *dZShapeInfo,
                           void *extraParams);

    /**
     *
     * @param opNum
@@ -24,6 +24,10 @@
#include <broadcasting_bool.h>
#include <scalar_bool.h>

#include <pairwise_int.h>
#include <broadcasting_int.h>
#include <scalar_int.h>

#include <loops/transform_float.h>
#include <loops/transform_bool.h>
#include <loops/transform_any.h>

@@ -242,6 +246,66 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
    BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}

void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}

////////////////////////////////////////////////////////////////////////
/**
 *
@@ -297,9 +361,42 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);

    if (zType != nd4j::DataType::BOOL)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType);

    BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
                                                   int opNum,
                                                   void *hX, Nd4jLong *hXShapeInfo,
                                                   void *dX, Nd4jLong *dXShapeInfo,
                                                   void *hY, Nd4jLong *hYShapeInfo,
                                                   void *dY, Nd4jLong *dYShapeInfo,
                                                   void *hZ, Nd4jLong *hZShapeInfo,
                                                   void *dZ, Nd4jLong *dZShapeInfo,
                                                   void *extraParams) {
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform requires integer data type", zType);

    BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES);
}

////////////////////////////////////////////////////////////////////////
/**
 *
@@ -739,6 +836,64 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
    BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalar, Nd4jLong *hScalarShapeInfo,
                                        void *dScalar, Nd4jLong *dScalarShapeInfo,
                                        void *extraParams, bool allowParallelism) {

#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType);

    BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *extraParams,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalars, Nd4jLong *hScalarShapeInfo,
                                        void *dScalars, Nd4jLong *dScalarShapeInfo,
                                        int *dimension, int dimensionLength,
                                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType);

    BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}

////////////////////////////////////////////////////////////////////////
/**
 *
@@ -38,19 +38,22 @@
#include <loops/reduce_same.h>
#include <loops/reduce_bool.h>
#include <loops/reduce_long.h>
#include <loops/broadcasting.h>
#include <loops/indexreduce.h>
#include <loops/pairwise_transform.h>
#include <loops/pairwise_bool.h>
#include <loops/pairwise_int.h>
#include <loops/broadcasting_bool.h>
#include <loops/broadcasting_int.h>
#include <loops/broadcasting.h>
#include <loops/reduce_float.h>
#include <loops/reduce3.h>
#include <loops/summarystatsreduce.h>
#include <loops/transform_same.h>
#include <loops/scalar.h>
#include <loops/random.h>
#include <loops/special_kernels.h>
#include <loops/scalar.h>
#include <loops/scalar_bool.h>
#include <loops/scalar_int.h>

using namespace nd4j;
@@ -152,6 +155,39 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
        throw cuda_exception::build("execPairwiseBoolTransform failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
                                                   int opNum,
                                                   void *hX, Nd4jLong *hXShapeInfo,
                                                   void *dX, Nd4jLong *dXShapeInfo,
                                                   void *hY, Nd4jLong *hYShapeInfo,
                                                   void *dY, Nd4jLong *dYShapeInfo,
                                                   void *hZ, Nd4jLong *hZShapeInfo,
                                                   void *dZ, Nd4jLong *dZShapeInfo,
                                                   void *extraParams) {

    auto stream = lc->getCudaStream();

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (!DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::INT32, zType);

    if (yType != xType || zType != xType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType);

    dim3 launchDims(256, 1024, 16384);

    BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES)

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execPairwiseIntTransform failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
                                                 int opNum,
@@ -252,6 +288,81 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
        throw cuda_exception::build("execInverseBroadcastBool failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    auto stream = lc->getCudaStream();

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (!DataTypeUtils::isZ(zType))
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");

    if (yType != xType || zType != xType)
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");

    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
        printf("F3B opNum:[%i]\n", opNum);

    dim3 launchDims(256, 256, 1024);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execBroadcastInt failed", res);
}

void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    auto stream = lc->getCudaStream();

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (!DataTypeUtils::isZ(zType))
        throw std::runtime_error("NativeOpExecutioner::execInverseBroadcastInt requires Z operand to have INT type");

    if (yType != xType || zType != xType)
        throw std::runtime_error("NativeOpExecutioner::execInverseBroadcastInt requires both X & Y operands to have same type");

    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
        printf("F3BI opNum:[%i]\n", opNum);

    dim3 launchDims(256, 256, 1024);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execInverseBroadcastInt failed", res);
}

////////////////////////////////////////////////////////////////////////
/**
 *
@@ -1114,6 +1225,75 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
        throw cuda_exception::build("execScalarBool B failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalar, Nd4jLong *hScalarShapeInfo,
                                        void *dScalar, Nd4jLong *dScalarShapeInfo,
                                        void *extraParams, bool allowParallelism) {

    auto stream = lc->getCudaStream();

    dim3 launchDims = dim3(256, 512, 8192);

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || zType != xType)
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");

    if (!DataTypeUtils::isZ(zType))
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type");

    BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES);

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execScalarInt failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *extraParams,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalars, Nd4jLong *hScalarShapeInfo,
                                        void *dScalars, Nd4jLong *dScalarShapeInfo,
                                        int *dimension, int dimensionLength,
                                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    auto stream = lc->getCudaStream();

    dim3 launchDims(256, 512, 8192);

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (xType != yType || zType != xType)
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");

    if (!DataTypeUtils::isZ(zType))
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type");

    BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execScalarInt B failed", res);
}

////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
                                     int opNum,
@@ -71,12 +71,25 @@ inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) {
        case broadcast::And: return pairwise::And;
        case broadcast::Or: return pairwise::Or;
        case broadcast::Xor: return pairwise::Xor;
        case broadcast::Not: return pairwise::Not;
        default:
            throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation");
    }
}

inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) {
    switch (op) {
        case broadcast::IntOps::IntAnd: return pairwise::IntOps::IntAnd;
        case broadcast::IntOps::IntOr: return pairwise::IntOps::IntOr;
        case broadcast::IntOps::IntXor: return pairwise::IntOps::IntXor;
        case broadcast::IntOps::ShiftLeft: return pairwise::IntOps::ShiftLeft;
        case broadcast::IntOps::ShiftRight: return pairwise::IntOps::ShiftRight;
        case broadcast::IntOps::CyclicShiftLeft: return pairwise::IntOps::CyclicShiftLeft;
        case broadcast::IntOps::CyclicShiftRight: return pairwise::IntOps::CyclicShiftRight;
        default:
            throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation");
    }
}
}

#endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H
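CyclicShiftLeft and CyclicShiftRight are named here but not defined in this diff; assuming they carry the usual bit-rotation semantics, the wrapped-around variant of a shift looks like this:

// Illustration of the cyclic-shift (bit rotation) ops named above, assuming
// standard rotate semantics: shifted-out bits wrap around to the other end.
#include <cstdint>
#include <cstdio>

uint32_t cyclicShiftLeft(uint32_t v, unsigned n) {
    n &= 31;                                   // keep the shift in range
    return (v << n) | (v >> ((32 - n) & 31));  // wrap high bits to the bottom
}

int main() {
    printf("0x%08x\n", cyclicShiftLeft(0x80000001u, 1));  // prints 0x00000003
    return 0;
}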

@@ -0,0 +1,164 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

/*
 * broadcasting_int.h
 *
 *  Created on: Dec 28, 2015
 *      Author: agibsonccc
 */

#ifndef BROADCASTING_INT_H_
#define BROADCASTING_INT_H_
#include <dll.h>
#include <helpers/shape.h>
#include <templatemath.h>
#include <pairwise_util.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include <helpers/DebugHelper.h>

#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef __JNI__
#include <jni.h>
#endif

#include <helpers/TAD.h>

#include "legacy_ops.h"

namespace functions {
namespace broadcast {

/**
 * Broadcast operation
 * for broadcasting a smaller tensor
 * along a bigger one.
 */
template<typename X>
class BroadcastInt {
public:

#ifdef __CUDACC__

    template<typename OpType>
    static __device__ void transformCuda(
            void *x,
            Nd4jLong *xShapeInfo,
            void *y,
            Nd4jLong *yShapeInfo,
            void *result,
            Nd4jLong *resultShapeInfo,
            int *dimension,
            int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    template <typename OpClass>
    static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    template<typename OpType>
    static __device__ void transformInverseCuda(
            void *x,
            Nd4jLong *xShapeInfo,
            void *y,
            Nd4jLong *yShapeInfo,
            void *result,
            Nd4jLong *resultShapeInfo,
            int *dimension,
            int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    template <typename OpClass>
    static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

#endif

    static void exec(int opNum,
                     void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);

    static void execInverse(int opNum,
                     void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);

    /**
     * CPU execution
     * @param x the input
     * @param xShapeInfo the x shape information
     * @param y the y data
     * @param yShapeInfo the y shape information
     * @param result the result
     * @param resultShapeInfo the result shape information
     * @param dimension the dimension to broadcast along
     * @param dimensionLength the length of the dimension buffer
     */
    template<typename OpType>
    static void exec(void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);

    template<typename OpType>
    static void execInverse(void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);
};
}
}

#endif /* BROADCASTING_INT_H_ */

@@ -0,0 +1,464 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#include <loops/broadcasting_int.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <LoopKind.h>
#include <helpers/ConstantTadHelper.h>

using namespace simdOps;

namespace functions {
namespace broadcast {

template <typename X>
void BroadcastInt<X>::exec(const int opNum,
                           void *x, Nd4jLong *xShapeInfo,
                           void *y, Nd4jLong *yShapeInfo,
                           void *z, Nd4jLong *zShapeInfo,
                           int *dimension, int dimensionLength,
                           Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffset,
                           Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffset) {
    DISPATCH_BY_OPNUM_T(exec, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, zTadOffset), BROADCAST_INT_OPS);
}

template <typename X>
void BroadcastInt<X>::execInverse(const int opNum,
                                  void *x, Nd4jLong *xShapeInfo,
                                  void *y, Nd4jLong *yShapeInfo,
                                  void *z, Nd4jLong *zShapeInfo,
                                  int *dimension, int dimensionLength,
                                  Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffset,
                                  Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffset) {
    DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, xTadShapeInfo, xTadOffset, zTadShapeInfo, zTadOffset), BROADCAST_INT_OPS);
}

template <typename X>
template <typename OpType>
void BroadcastInt<X>::exec(void *vx, Nd4jLong *xShapeInfo,
                           void *vy, Nd4jLong *yShapeInfo,
                           void *vz, Nd4jLong *zShapeInfo,
                           int *dimension, int dimensionLength,
                           Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffset,
                           Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffset) {

    auto x = reinterpret_cast<X *>(vx);
    auto y = reinterpret_cast<X *>(vy);
    auto z = reinterpret_cast<X *>(vz);

    // decompose into several sub-tads after moving all dimensions (in sorted order)
    // to the back; a permuted version of the x shape info is used to set up the tad problem
    auto xTadShapeShapeInfo = xTadShapeInfo;
    auto tadOffsets = xTadOffset;

    if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
        auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);

        xTadShapeShapeInfo = tadPack.primaryShapeInfo();
        tadOffsets = tadPack.primaryOffsets();
    }

    //int *resultStride = shape::stride(xTadShapeShapeInfo);
    unsigned int tadLength = shape::length(xTadShapeShapeInfo);
    unsigned int tads = shape::length(xShapeInfo) / tadLength;

    if (zTadShapeInfo == nullptr) {
        zTadShapeInfo = xTadShapeShapeInfo;
        zTadOffset = tadOffsets;
    }

    auto lenZ = shape::length(zTadShapeInfo);
    auto lenY = shape::length(yShapeInfo);

    int tadsPerThread = tads / TAD_THRESHOLD;
    int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
    threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());

    auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zTadShapeInfo);

    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

    if (kindOfLoop == nd4j::LoopKind::EWS1) {
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oX = x + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], y[f]);
        }
    }
    else if (kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oX = x + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
        }
    }
    else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {

        uint tadShapeShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {

            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            // TODO: cover this code branch with tests
            // all this stuff already happens within a thread
            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                oZ[offset] = OpType::op(oX[offset], y[offset]);
            }
        }
    }
    else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {

        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {

            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(oX[offset], y[offset]);
            }
        }
    }
    else if (shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {

        uint tadShapeShapeInfoCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {

            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                oZ[offset] = OpType::op(oX[offset], y[yOffset]);
            }
        }
    }
    else if (shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {

        uint tadShapeShapeInfoCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {

            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                oZ[offset] = OpType::op(oX[xOffset], y[offset]);
            }
        }
    }
    else {

        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {

            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
            }
        }
    }
}
|
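
    // -----------------------------------------------------------------------
    // Editor's illustrative sketch (not part of the original diff): a minimal,
    // single-threaded reference for what the EWS1 branch of exec() above
    // computes. The function name and the ShiftLeft stand-in op are
    // assumptions for illustration only, not part of the library.
    static inline void referenceBroadcastIntSketch(const int *x, const int *y, int *z,
                                                   int tads, int tadLength) {
        for (int i = 0; i < tads; i++)              // one pass per TAD (sub-array) of x
            for (int f = 0; f < tadLength; f++)     // element-wise int op inside the TAD
                z[i * tadLength + f] = x[i * tadLength + f] << y[f];
    }
    // -----------------------------------------------------------------------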

    template <typename X>
    template <typename OpType>
    void BroadcastInt<X>::execInverse(void *vx,
                                      Nd4jLong *xShapeInfo,
                                      void *vy,
                                      Nd4jLong *yShapeInfo,
                                      void *vz,
                                      Nd4jLong *zShapeInfo,
                                      int *dimension,
                                      int dimensionLength,
                                      Nd4jLong *yTadShapeInfo,
                                      Nd4jLong *yTadOffset,
                                      Nd4jLong *zTadShapeInfo,
                                      Nd4jLong *zTadOffset) {

        auto x = reinterpret_cast<X *>(vx);
        auto y = reinterpret_cast<X *>(vy);
        auto z = reinterpret_cast<X *>(vz);

        // decompose into several sub-tads after moving all dimensions
        // (in sorted order) to the back; this is the permuted version of
        // the y shape info used to set up the TAD problem
        auto yTadShapeShapeInfo = yTadShapeInfo;
        auto tadOffsets = yTadOffset;

        if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
            auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);

            yTadShapeShapeInfo = tadPack.primaryShapeInfo();
            tadOffsets = tadPack.primaryOffsets();
        }

        //int *resultStride = shape::stride(yTadShapeShapeInfo);
        unsigned int tadLength = shape::length(yTadShapeShapeInfo);
        unsigned int tads = shape::length(yShapeInfo) / tadLength;

        if (zTadShapeInfo == nullptr) {
            zTadShapeInfo = yTadShapeShapeInfo;
            zTadOffset = tadOffsets;
        }

        auto lenZ = shape::length(zTadShapeInfo);
        auto lenX = shape::length(xShapeInfo);

        int tadsPerThread = tads / TAD_THRESHOLD;
        int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
        threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());

        auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
        auto xEws = shape::elementWiseStride(xShapeInfo);
        auto zEws = shape::elementWiseStride(zTadShapeInfo);

        const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);

        if (kindOfLoop == nd4j::LoopKind::EWS1) {
            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {
                auto oY = y + tadOffsets[i];
                auto oZ = z + zTadOffset[i];

                PRAGMA_OMP_SIMD
                for (unsigned int f = 0; f < tadLength; f++)
                    oZ[f] = OpType::op(x[f], oY[f]);
            }
        }
        else if (kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {
                auto oY = y + tadOffsets[i];
                auto oZ = z + zTadOffset[i];

                PRAGMA_OMP_SIMD
                for (uint f = 0; f < tadLength; f++)
                    oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
            }
        }
        else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {

            uint tadShapeShapeInfoCast[MAX_RANK];
            bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {

                auto oY = y + tadOffsets[i];
                auto oZ = z + zTadOffset[i];

                // TODO: cover this code branch with tests
                // all of this already happens within a single thread
                PRAGMA_OMP_SIMD
                for (int f = 0; f < tadLength; f++) {
                    auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                    oZ[offset] = OpType::op(x[offset], oY[offset]);
                }
            }
        }
        else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {

            uint tadShapeShapeInfoCast[MAX_RANK];
            uint tadShapeInfoZCast[MAX_RANK];
            bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
            bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {

                auto oZ = z + zTadOffset[i];
                auto oY = y + tadOffsets[i];

                PRAGMA_OMP_SIMD
                for (int f = 0; f < tadLength; f++) {
                    auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                    auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                    oZ[zOffset] = OpType::op(x[offset], oY[offset]);
                }
            }
        }
        else if (shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {

            uint tadShapeShapeInfoCast[MAX_RANK];
            uint xShapeInfoCast[MAX_RANK];
            bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {

                auto oZ = z + zTadOffset[i];
                auto oY = y + tadOffsets[i];

                PRAGMA_OMP_SIMD
                for (int f = 0; f < tadLength; f++) {
                    auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                    auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                    oZ[offset] = OpType::op(x[xOffset], oY[offset]);
                }
            }
        }
        else if (shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {

            uint tadShapeShapeInfoCast[MAX_RANK];
            uint xShapeInfoCast[MAX_RANK];
            bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {

                auto oZ = z + zTadOffset[i];
                auto oY = y + tadOffsets[i];

                PRAGMA_OMP_SIMD
                for (int f = 0; f < tadLength; f++) {
                    auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                    auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                    oZ[offset] = OpType::op(x[offset], oY[yOffset]);
                }
            }
        }
        else {

            uint xShapeInfoCast[MAX_RANK];
            uint tadShapeShapeInfoCast[MAX_RANK];
            uint tadShapeInfoZCast[MAX_RANK];
            bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
            bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

            PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
            for (int i = 0; i < tads; i++) {

                auto oZ = z + zTadOffset[i];
                auto oY = y + tadOffsets[i];

                PRAGMA_OMP_SIMD
                for (int f = 0; f < tadLength; f++) {
                    auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                    auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                    auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                    oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
                }
            }
        }
    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}

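// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original diff): execInverse()
// above tads along y rather than x, i.e. the smaller array x is applied
// against every TAD of y. Reference semantics, single-threaded; the function
// name and the ShiftRight stand-in op are assumptions for illustration only.
static inline void referenceInverseBroadcastIntSketch(const int *x, const int *y, int *z,
                                                      int tads, int tadLength) {
    for (int i = 0; i < tads; i++)              // one pass per TAD of y
        for (int f = 0; f < tadLength; f++)     // x is reused for every TAD
            z[i * tadLength + f] = x[f] >> y[i * tadLength + f];
}
// ---------------------------------------------------------------------------
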
@@ -0,0 +1,309 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <loops/pairwise_int.h>
#include <types/types.h>
#include <LoopKind.h>
#include <OmpLaunchHelper.h>

using namespace simdOps;

namespace functions {
namespace pairwise_transforms {

    template <typename X>
    void PairWiseIntTransform<X>::exec(
            const int opNum,
            void *x,
            Nd4jLong xEws,
            void *y,
            Nd4jLong yEws,
            void *z,
            Nd4jLong zEws,
            void *extraParams,
            Nd4jLong n) {
        DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
                                         xEws,
                                         y,
                                         yEws,
                                         z,
                                         zEws,
                                         extraParams,
                                         n), PAIRWISE_INT_OPS);
    }

    template <typename X>
    template <typename OpType>
    void PairWiseIntTransform<X>::exec(void *vx,
                                       Nd4jLong xEws,
                                       void *vy,
                                       Nd4jLong yEws,
                                       void *vz,
                                       Nd4jLong zEws,
                                       void *vextraParams,
                                       const Nd4jLong n) {

        auto x = reinterpret_cast<X *>(vx);
        auto y = reinterpret_cast<X *>(vy);
        auto z = reinterpret_cast<X *>(vz);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        nd4j::OmpLaunchHelper info(n);

        if (xEws == 1 && yEws == 1 && zEws == 1) {

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                Nd4jLong threadOffset = info.getThreadOffset(threadNum);
                auto xi = x + threadOffset;
                auto yi = y + threadOffset;
                auto zi = z + threadOffset;
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++)
                    zi[i] = OpType::op(xi[i], yi[i], extraParams);
            }
        }
        else {

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                Nd4jLong threadOffset = info.getThreadOffset(threadNum);
                auto xi = x + xEws * threadOffset;
                auto yi = y + yEws * threadOffset;
                auto zi = z + zEws * threadOffset;
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++)
                    zi[i * zEws] = OpType::op(xi[i * xEws], yi[i * yEws], extraParams);
            }
        }
    }

    template <typename X>
    void PairWiseIntTransform<X>::exec(
            const int opNum,
            void *x,
            Nd4jLong *xShapeInfo,
            void *y,
            Nd4jLong *yShapeInfo,
            void *z,
            Nd4jLong *zShapeInfo,
            void *extraParams) {
        DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
                                         xShapeInfo,
                                         y,
                                         yShapeInfo,
                                         z,
                                         zShapeInfo,
                                         extraParams),
                            PAIRWISE_INT_OPS);
    }

    template <typename X>
    template <typename OpType>
    void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
                                       void *vy, Nd4jLong* yShapeInfo,
                                       void *vz, Nd4jLong* zShapeInfo,
                                       void *vextraParams) {

        auto x = reinterpret_cast<X *>(vx);
        auto y = reinterpret_cast<X *>(vy);
        auto z = reinterpret_cast<X *>(vz);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        auto n = shape::length(xShapeInfo);
        auto xEws = shape::elementWiseStride(xShapeInfo);
        auto yEws = shape::elementWiseStride(yShapeInfo);
        auto zEws = shape::elementWiseStride(zShapeInfo);

        nd4j::OmpLaunchHelper info(n);

        if (shape::isScalar(yShapeInfo)) {

            uint xShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

            if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        z[offset] = OpType::op(x[offset], y[0], extraParams);
                    }
                }
            }
            else {

                uint zShapeInfoCast[MAX_RANK];
                const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                        z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
                    }
                }
            }
            return;
        }

        const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo);
        const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);

        if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
            exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
        }
        else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { // not the same shape
            exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
        }
        else {

            if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

                uint xShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        z[offset] = OpType::op(x[offset], y[offset], extraParams);
                    }
                }
            }
            else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {

                uint xShapeInfoCast[MAX_RANK];
                uint zShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                        z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
                    }
                }
            }
            else if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

                uint xShapeInfoCast[MAX_RANK];
                uint yShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                        z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
                    }
                }
            }
            else if (shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {

                uint xShapeInfoCast[MAX_RANK];
                uint yShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                        z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
                    }
                }
            }
            else {

                uint xShapeInfoCast[MAX_RANK];
                uint yShapeInfoCast[MAX_RANK];
                uint zShapeInfoCast[MAX_RANK];
                const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
                const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
                const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

                PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
                {
                    auto threadNum = omp_get_thread_num();
                    auto threadOffset = info.getThreadOffset(threadNum);
                    auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                    PRAGMA_OMP_SIMD
                    for (Nd4jLong i = 0; i < ulen; i++) {
                        auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                        auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                        auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                        z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
                    }
                }
            }
        }
    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}
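
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original diff): the
// OmpLaunchHelper pattern used throughout the file above splits a length-n
// buffer into contiguous per-thread chunks — each thread receives a start
// offset and an iteration count. The chunk math below is an assumed
// simplification for illustration, not the helper's exact implementation.
static inline void splitForThreadsSketch(long long n, int numThreads, int threadNum,
                                         long long &offset, long long &span) {
    long long chunk = n / numThreads;            // base chunk per thread
    offset = (long long)threadNum * chunk;       // contiguous start of this thread's range
    span = (threadNum == numThreads - 1)         // last thread absorbs the remainder
           ? n - offset : chunk;
}
// ---------------------------------------------------------------------------
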
@@ -0,0 +1,255 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "../scalar_int.h"
#include <op_boilerplate.h>
#include <types/types.h>
#include <LoopKind.h>

#include "../legacy_ops.h"

using namespace simdOps;

namespace functions {
namespace scalar {

    template<typename X>
    template<typename OpType>
    void ScalarIntTransform<X>::transform(void *vx, Nd4jLong *xShapeInfo,
                                          void *vextraParams,
                                          void *vz, Nd4jLong *zShapeInfo,
                                          void *vscalars,
                                          int *dimension, int dimensionLength,
                                          Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                          Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {

        auto x = reinterpret_cast<X *>(vx);
        auto z = reinterpret_cast<X *>(vz);
        auto scalars = reinterpret_cast<X *>(vscalars);
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        if (zTadShapeInfo == nullptr) {
            zTadShapeInfo = xTadShapeInfo;
            zTadOffsets = xTadOffsets;
        }

        // tad preparation
        const int xTadEws = shape::elementWiseStride(xTadShapeInfo);
        const int zTadEws = shape::elementWiseStride(zTadShapeInfo);
        const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength);
        const int numTads = shape::length(xShapeInfo) / tadLength;

        nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo);

        if (kindOfLoop != nd4j::LoopKind::EWS1 && kindOfLoop != nd4j::LoopKind::EWSNONZERO) {
            printf("ScalarIntTransform<X>::transform: super-bad loop visited. Shouldn't ever happen\n");
            return;
        }

        int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());

        if (kindOfLoop == nd4j::LoopKind::EWS1) {
            PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
            for (unsigned int r = 0; r < numTads; r++) {
                auto oZ = z + zTadOffsets[r];
                auto oX = x + xTadOffsets[r];

                PRAGMA_OMP_SIMD
                for (unsigned int f = 0; f < tadLength; f++)
                    oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
            }
        }
        else { // kindOfLoop == nd4j::LoopKind::EWSNONZERO
            PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
            for (unsigned int r = 0; r < numTads; r++) {
                auto oZ = z + zTadOffsets[r];
                auto oX = x + xTadOffsets[r];

                PRAGMA_OMP_SIMD
                for (unsigned int f = 0; f < tadLength; f++)
                    oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
            }
        }
    }

    template<typename X>
    void ScalarIntTransform<X>::transform(int opNum,
                                          void *x,
                                          Nd4jLong *xShapeInfo,
                                          void *extraParams,
                                          void *z,
                                          Nd4jLong *zShapeInfo,
                                          void *scalars,
                                          int *dimension,
                                          int dimensionLength,
                                          Nd4jLong *xTadShapeInfo,
                                          Nd4jLong *xTadOffsets,
                                          Nd4jLong *zTadShapeInfo,
                                          Nd4jLong *zTadOffsets) {
        DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS);
    }

    template<typename X>
    void ScalarIntTransform<X>::transform(const int opNum,
                                          void *x,
                                          Nd4jLong xEws,
                                          void *z,
                                          Nd4jLong zEws,
                                          void *scalar,
                                          void *extraParams,
                                          const Nd4jLong n) {
        DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS);
    }

    template<typename X>
    void ScalarIntTransform<X>::transform(const int opNum,
                                          void *x,
                                          Nd4jLong *xShapeInfo,
                                          void *z,
                                          Nd4jLong *zShapeInfo,
                                          void *scalar,
                                          void *extraParams) {
        DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS);
    }

    template<typename X>
    template<typename OpType>
    void ScalarIntTransform<X>::transform(void *vx,
                                          Nd4jLong *xShapeInfo,
                                          void *vz,
                                          Nd4jLong *zShapeInfo,
                                          void *vscalar,
                                          void *vextraParams) {

        auto x = reinterpret_cast<X *>(vx);
        auto z = reinterpret_cast<X *>(vz);
        auto scalar = reinterpret_cast<X *>(vscalar)[0];
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        auto xEws = shape::elementWiseStride(xShapeInfo);
        auto zEws = shape::elementWiseStride(zShapeInfo);
        auto len = shape::length(xShapeInfo);

        // nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);

        nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);

        if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
            transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len);
            return;
        }

        uint xShapeInfoCast[MAX_RANK];
        const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);

        nd4j::OmpLaunchHelper info(len);

        if (shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (unsigned int i = 0; i < ulen; i++) {
                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
                    z[offset] = OpType::op(x[offset], scalar, extraParams);
                }
            }
        }
        else {

            uint zShapeInfoCast[MAX_RANK];
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (unsigned int i = 0; i < ulen; i++) {
                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
                    z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
                }
            }
        }
    }

    template<typename X>
    template<typename OpType>
    void ScalarIntTransform<X>::transform(void *vx,
                                          Nd4jLong xEws,
                                          void *vz,
                                          Nd4jLong zEws,
                                          void *vscalar,
                                          void *vextraParams,
                                          const Nd4jLong len) {

        auto x = reinterpret_cast<X *>(vx);
        auto z = reinterpret_cast<X *>(vz);
        auto scalar = reinterpret_cast<X *>(vscalar)[0];
        auto extraParams = reinterpret_cast<X *>(vextraParams);

        nd4j::OmpLaunchHelper info(len);

        if (xEws == 1 && zEws == 1) {

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto xi = x + threadOffset;
                auto zi = z + threadOffset;
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (unsigned int i = 0; i < ulen; i++)
                    zi[i] = OpType::op(xi[i], scalar, extraParams);
            }
        }
        else {

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto xi = x + xEws * threadOffset;
                auto zi = z + zEws * threadOffset;
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (unsigned int i = 0; i < ulen; i++)
                    zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
            }
        }
    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);

}
}
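
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original diff): the
// scalar-along-dimension path above pairs every TAD r with its own scalar
// scalars[r]. Single-threaded reference; the function name and the ShiftRight
// stand-in op are assumptions for illustration only.
static inline void referenceScalarAlongDimensionSketch(const int *x, const int *scalars,
                                                       int *z, int numTads, int tadLength) {
    for (int r = 0; r < numTads; r++)           // one scalar per TAD
        for (int f = 0; f < tadLength; f++)
            z[r * tadLength + f] = x[r * tadLength + f] >> scalars[r];
}
// ---------------------------------------------------------------------------
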
@@ -0,0 +1,291 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#include <loops/broadcasting_int.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <Environment.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <stdexcept>
#include <StringUtils.h>

using namespace simdOps;

//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(
        void *x,
        Nd4jLong *xShapeInfo,
        void *y,
        Nd4jLong *yShapeInfo,
        void *z,
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}

//////////////////////////////////////////////////////////////////////////
template<typename X, typename OpClass>
static __global__ void broadcastIntInverseSimple(
        void *x,
        Nd4jLong *xShapeInfo,
        void *y,
        Nd4jLong *yShapeInfo,
        void *z,
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    functions::broadcast::BroadcastInt<X>::template transformInverseCuda<OpClass>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}

namespace functions {
namespace broadcast {

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    template <typename OpClass>
    __host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    __host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    template <typename OpClass>
    __host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        broadcastIntInverseSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    __host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    template <typename OpType>
    __device__ void BroadcastInt<X>::transformInverseCuda(
            void *vx, Nd4jLong *xShapeInfo,
            void *vy, Nd4jLong *yShapeInfo,
            void *vz, Nd4jLong *zShapeInfo,
            int *dimension, int dimensionLength,
            Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

        if (tadOnlyShapeInfoZ == nullptr) {
            tadOnlyShapeInfoZ = tadOnlyShapeInfo;
            tadOffsetsZ = tadOffsets;
        }

        auto x = reinterpret_cast<X*>(vx);
        auto y = reinterpret_cast<X*>(vy);
        auto z = reinterpret_cast<X*>(vz);

        // decompose into several sub-tads after moving all dimensions
        // (in sorted order) to the back; tadOnlyShapeInfo is the permuted
        // version of the y shape info used to set up the TAD problem
        __shared__ Nd4jLong tadLength;
        __shared__ Nd4jLong tadEWS;
        __shared__ int numTads;
        __shared__ Nd4jLong xEWS;
        __shared__ Nd4jLong zEWS;

        if (threadIdx.x == 0) {
            tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, dimension, dimensionLength);
            tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
            numTads = shape::length(yShapeInfo) / tadLength;
            xEWS = shape::elementWiseStride(xShapeInfo);
            zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
        }
        __syncthreads();

        for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
            auto rZ = z + tadOffsetsZ[r];
            auto rY = y + tadOffsets[r];

            if (tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {

                for (int i = threadIdx.x; i < tadLength; i += blockDim.x)
                    rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
            }
            else {
                // it is expected that the x and z tads and the y array all have the same length
                for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) {
                    auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
                    auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
                    auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);

                    rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
                }
            }
        }
    }

    //////////////////////////////////////////////////////////////////////////
    template<typename X>
    template <typename OpType>
    __device__ void BroadcastInt<X>::transformCuda(
            void *vx, Nd4jLong *xShapeInfo,
            void *vy, Nd4jLong *yShapeInfo,
            void *vz, Nd4jLong *zShapeInfo,
            int *dimension, int dimensionLength,
            Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

        if (tadOnlyShapeInfoZ == nullptr) {
            tadOnlyShapeInfoZ = tadOnlyShapeInfo;
            tadOffsetsZ = tadOffsets;
        }

        auto x = reinterpret_cast<X*>(vx);
        auto y = reinterpret_cast<X*>(vy);
        auto z = reinterpret_cast<X*>(vz);

        // decompose into several sub-tads after moving all dimensions
        // (in sorted order) to the back; tadOnlyShapeInfo is the permuted
        // version of the x shape info used to set up the TAD problem
        __shared__ Nd4jLong tadLength;
        __shared__ Nd4jLong tadEWS;
        __shared__ int numTads;
        __shared__ Nd4jLong yEWS;
        __shared__ Nd4jLong zEWS;

        if (threadIdx.x == 0) {
            tadLength = shape::length(tadOnlyShapeInfo); // shape::tadLength(xShapeInfo, dimension, dimensionLength);
            tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
            numTads = shape::length(xShapeInfo) / tadLength;
            yEWS = shape::elementWiseStride(yShapeInfo);
            zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
        }
        __syncthreads();

        __shared__ X *rZ;
        __shared__ X *rX;

        for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

            if (threadIdx.x == 0) {
                rZ = z + tadOffsetsZ[r];
                rX = x + tadOffsets[r];
            }
            __syncthreads();

            if (tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {

                for (int i = threadIdx.x; i < tadLength; i += blockDim.x)
                    rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
            }
            else {
                // it is expected that the x and z tads and the y array all have the same length
                for (Nd4jLong i = threadIdx.x; i < tadLength; i += blockDim.x) {
                    auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
                    auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
                    auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);

                    rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
                }
            }
        }
    }

    template<typename X>
    void BroadcastInt<X>::exec(int opNum,
                               void *x,
                               Nd4jLong *xShapeInfo,
                               void *y,
                               Nd4jLong *yShapeInfo,
                               void *result,
                               Nd4jLong *resultShapeInfo,
                               int *dimension,
                               int dimensionLength,
                               Nd4jLong *tadShapeInfo,
                               Nd4jLong *tadOffset,
                               Nd4jLong *tadShapeInfoZ,
                               Nd4jLong *tadOffsetZ) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    void BroadcastInt<X>::execInverse(int opNum,
                                      void *x,
                                      Nd4jLong *xShapeInfo,
                                      void *y,
                                      Nd4jLong *yShapeInfo,
                                      void *result,
                                      Nd4jLong *resultShapeInfo,
                                      int *dimension,
                                      int dimensionLength,
                                      Nd4jLong *tadShapeInfo,
                                      Nd4jLong *tadOffset,
                                      Nd4jLong *tadShapeInfoZ,
                                      Nd4jLong *tadOffsetZ) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    template<typename OpType>
    void BroadcastInt<X>::exec(void *x,
                               Nd4jLong *xShapeInfo,
                               void *y,
                               Nd4jLong *yShapeInfo,
                               void *result,
                               Nd4jLong *resultShapeInfo,
                               int *dimension,
                               int dimensionLength,
                               Nd4jLong *tadShapeInfo,
                               Nd4jLong *tadOffset,
                               Nd4jLong *tadShapeInfoZ,
                               Nd4jLong *tadOffsetZ) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    template<typename OpType>
    void BroadcastInt<X>::execInverse(void *x,
                                      Nd4jLong *xShapeInfo,
                                      void *y,
                                      Nd4jLong *yShapeInfo,
                                      void *result,
                                      Nd4jLong *resultShapeInfo,
                                      int *dimension,
                                      int dimensionLength,
                                      Nd4jLong *tadShapeInfo,
                                      Nd4jLong *tadOffset,
                                      Nd4jLong *tadShapeInfoZ,
                                      Nd4jLong *tadOffsetZ) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}
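
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original diff): the work split
// used by transformCuda above, reduced to a toy kernel. Blocks stride over
// TADs and threads stride over elements inside one TAD; the kernel name and
// the ShiftLeft stand-in op are assumptions for illustration only.
template <typename T>
static __global__ void toyBroadcastIntKernel(const T *x, const T *y, T *z,
                                             int numTads, int tadLength) {
    for (int r = blockIdx.x; r < numTads; r += gridDim.x)           // TADs across blocks
        for (int f = threadIdx.x; f < tadLength; f += blockDim.x)   // elements across threads
            z[r * tadLength + f] = x[r * tadLength + f] << y[f];
}
// ---------------------------------------------------------------------------
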
@@ -0,0 +1,173 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018

#ifndef PAIRWISE_INT_CU
#define PAIRWISE_INT_CU

#include "../pairwise_int.h"

using namespace simdOps;

////////////////////////////////////////////////////////////////////////////////
template <typename X, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
                                            void *vy, Nd4jLong *yShapeInfo,
                                            void *vz, Nd4jLong *zShapeInfo,
                                            void *vextraParams) {

    auto x = reinterpret_cast<X*>(vx);
    auto y = reinterpret_cast<X*>(vy);
    auto z = reinterpret_cast<X*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);

    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    __shared__ int xEws;
    __shared__ int yEws;
    __shared__ int zEws;
    __shared__ char xOrder;
    __shared__ char yOrder;
    __shared__ char zOrder;
    __shared__ Nd4jLong len;

    if (threadIdx.x == 0) {
        xEws = shape::elementWiseStride(xShapeInfo);
        yEws = shape::elementWiseStride(yShapeInfo);
        zEws = shape::elementWiseStride(zShapeInfo);
        xOrder = shape::order(xShapeInfo);
        yOrder = shape::order(yShapeInfo);
        zOrder = shape::order(zShapeInfo);
        len = shape::length(xShapeInfo);
    }
    __syncthreads();

    if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) {
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams);
        }
    }
    else if (vx == vz) {
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
            auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);

            z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
        }
    }
    else {
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
            auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
            auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);

            z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
        }
    }
}

namespace functions {
namespace pairwise_transforms {

    ////////////////////////////////////////////////////////////////////////////////
    template<typename X>
    template<typename OpType>
    void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                                                             void *vx, Nd4jLong *xShapeInfo,
                                                             void *vy, Nd4jLong *yShapeInfo,
                                                             void *vz, Nd4jLong *zShapeInfo,
                                                             void *vextraParams) {

        pairwiseSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
    }

    ////////////////////////////////////////////////////////////////////////////////
    template<typename X>
    void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
        auto xType = nd4j::DataTypeUtils::fromT<X>();

        DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);
    }

    template<typename X>
    void PairWiseIntTransform<X>::exec(
            const int opNum,
            void *dx,
            Nd4jLong *xShapeBuffer,
            void *y,
            Nd4jLong *yShapeBuffer,
            void *result,
            Nd4jLong *resultShapeBuffer,
            void *extraParams) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    void PairWiseIntTransform<X>::exec(
            const int opNum,
            void *dx,
            Nd4jLong xStride,
            void *y,
            Nd4jLong yStride,
            void *result,
            Nd4jLong resultStride,
            void *extraParams,
            Nd4jLong n) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    template<typename OpType>
    void PairWiseIntTransform<X>::exec(
            void *vx,
            Nd4jLong* xShapeBuffer,
            void *vy,
            Nd4jLong* yShapeBuffer,
            void *vresult,
            Nd4jLong* resultShapeBuffer,
            void *vextraParams) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    template<typename X>
    template<typename OpType>
    void PairWiseIntTransform<X>::exec(void *vx,
                                       Nd4jLong xStride,
                                       void *vy,
                                       Nd4jLong yStride,
                                       void *vresult,
                                       Nd4jLong resultStride,
                                       void *vextraParams,
                                       const Nd4jLong n) {
        // no-op: the CPU-style exec path is not used by the CUDA backend
    }

    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}

#endif // PAIRWISE_INT_CU
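
// ---------------------------------------------------------------------------
// Editor's illustrative sketch (not part of the original diff): the flat
// grid-stride loop used by pairwiseSimpleShaped above — one linear index space
// covered cooperatively by every thread in the grid. The kernel name and the
// bitwise-and stand-in op are assumptions for illustration only.
template <typename T>
static __global__ void toyPairwiseIntKernel(const T *x, const T *y, T *z, long long len) {
    long long tid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    for (long long i = tid; i < len; i += (long long)gridDim.x * blockDim.x)  // grid-stride
        z[i] = x[i] & y[i];
}
// ---------------------------------------------------------------------------
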
@ -0,0 +1,269 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include "../scalar_int.h"
|
||||
#include <op_boilerplate.h>
|
||||
#include <types/types.h>
|
||||
|
||||
#include "../legacy_ops.h"
|
||||
|
||||
using namespace simdOps;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
__global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo,
|
||||
void *extraParams,
|
||||
void *z, Nd4jLong *zShapeInfo,
|
||||
void *scalars,
|
||||
int *dimension, int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||
|
||||
functions::scalar::ScalarIntTransform<X>::template transformCuda<OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename OpType>
|
||||
__global__ void scalarSimpleShaped(void* x, void *y, Nd4jLong *xShapeInfo, void *params, void *z, Nd4jLong *zShapeInfo, int *allocationBuffer) {
|
||||
|
||||
functions::scalar::ScalarIntTransform<X>::template transformCuda<OpType>(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// *********************************************************************//
|
||||
// *********************************************************************//
|
||||
namespace functions {
|
||||
namespace scalar {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X>
|
||||
template<typename OpType>
|
||||
__device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
|
||||
void *vy, Nd4jLong *yShapeInfo,
|
||||
void *vparams,
|
||||
void *vz, Nd4jLong *zShapeInfo,
|
||||
int *allocationBuffer) {
|
||||
auto scalar = reinterpret_cast<X*>(vscalar)[0];
|
||||
auto y = reinterpret_cast<X*>(vy);
|
||||
auto params = reinterpret_cast<X*>(vparams);
|
||||
auto z = reinterpret_cast<X*>(vz);
|
||||
|
||||
auto yRank = shape::rank(yShapeInfo);
|
||||
auto yEWS = shape::elementWiseStride(yShapeInfo);
|
||||
auto yShape = shape::shapeOf(yShapeInfo);
|
||||
auto yStride = shape::stride(yShapeInfo);
|
||||
|
||||
auto zRank = shape::rank(zShapeInfo);
|
||||
auto zEWS = shape::elementWiseStride(zShapeInfo);
|
||||
auto zShape = shape::shapeOf(zShapeInfo);
|
||||
auto zStride = shape::stride(zShapeInfo);
|
||||
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
__shared__ int len;
|
||||
if(threadIdx.x == 0)
|
||||
len = shape::length(yShapeInfo);
|
||||
__syncthreads();
|
||||
|
||||
if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) {
|
||||
transformCuda<OpType>(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer);
|
||||
}
|
||||
else {
|
||||
for (Nd4jLong i = tid; i < len; i+= totalThreads)
|
||||
z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X>
|
||||
template<typename OpType>
|
||||
__device__ void ScalarIntTransform<X>::transformCuda(Nd4jLong len,
|
||||
void* vx,
|
||||
void *vy, Nd4jLong yEWS,
|
||||
void *vparams,
|
||||
void *vz, Nd4jLong zEWS,
|
||||
int *allocationBuffer) {
|
||||
|
||||
auto x = reinterpret_cast<X*>(vx)[0];
|
||||
auto y = reinterpret_cast<X*>(vy);
|
||||
auto z = reinterpret_cast<X*>(vz);
|
||||
auto params = reinterpret_cast<X*>(vparams);
|
||||
|
||||
int totalThreads = gridDim.x * blockDim.x;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
Nd4jLong i = tid;
|
||||
if(yEWS == 1 && zEWS == 1) {
|
||||
for (; i < len; i += totalThreads)
|
||||
z[i] = OpType::op(y[i], x, params);
|
||||
}
|
||||
else {
|
||||
for (; i < len; i += totalThreads)
|
||||
z[i * zEWS] = OpType::op(y[i * yEWS], x, params);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template<typename X>
|
||||
template<typename OpType>
|
||||
__device__ void ScalarIntTransform<X>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
|
||||
void *vextraParams,
|
||||
void *vz, Nd4jLong *zShapeInfo,
|
||||
void *vscalars,
|
||||
int *dimension, int dimensionLength,
|
||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||
auto x = reinterpret_cast<X*>(vx);
|
||||
auto scalars = reinterpret_cast<X*>(vscalars);
|
||||
auto z = reinterpret_cast<X*>(vz);
|
||||
auto extraParams = reinterpret_cast<X*>(vextraParams);
|
||||
|
||||
if (tadShapeInfoZ == nullptr) {
|
||||
tadShapeInfoZ = tadShapeInfo;
|
||||
tadOffsetsZ = tadOffsets;
|
||||
}
|
||||
|
||||
// tad preparation
|
||||
auto tadEws = shape::elementWiseStride(tadShapeInfo);
|
||||
auto zEws = shape::elementWiseStride(tadShapeInfoZ);
|
||||
auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
|
||||
auto numTads =shape::length(xShapeInfo) / tadLength;
|
||||
|
||||
if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) {
|
||||
|
||||
// main loop, rolling over tads
|
||||
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
|
||||
X *oZ = z + tadOffsetsZ[r];
|
||||
X *oX = x + tadOffsets[r];
|
||||
|
||||
auto s = scalars[r];
|
||||
|
||||
for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
|
||||
oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams);
|
||||
}
|
||||
} else {
|
||||
// main loop, rolling over tads
|
||||
for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
|
||||
X *oZ = z + tadOffsetsZ[r];
|
||||
X *oX = x + tadOffsets[r];
|
||||
|
||||
auto s = scalars[r];
|
||||
|
||||
for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
|
||||
oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams);
|
||||
}
|
||||
}
|
||||
}
    ////////////////////////////////////////////////////////////////////////
    template<typename X>
    template <typename OpType>
    _CUDA_H void ScalarIntTransform<X>::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream,
                                                                   void *x, Nd4jLong *xShapeInfo,
                                                                   void *z, Nd4jLong *zShapeInfo,
                                                                   void *scalars,
                                                                   void *extraParams,
                                                                   int *dimension, int dimensionLength,
                                                                   Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                                                   Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

        scalarAlongDimension<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
    }

    ////////////////////////////////////////////////////////////////////////
    template<typename X>
    template<typename OpType>
    void _CUDA_H ScalarIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                                                           void *vx, Nd4jLong *xShapeInfo,
                                                           void *vz, Nd4jLong *zShapeInfo,
                                                           void* vscalar,
                                                           void *vextraParams, int *allocPointer) {

        scalarSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer);
    }

    ////////////////////////////////////////////////////////////////////////
    template<typename X>
    void ScalarIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream,
                                                  int opNum,
                                                  void *vx, Nd4jLong *xShapeInfo,
                                                  void *vz, Nd4jLong *zShapeInfo,
                                                  void* vscalar,
                                                  void *vextraParams) {

        if (nd4j::Environment::getInstance()->isDebugAndVerbose())
            printf("H14 opNum:[%i]\n", opNum);

        DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS);
    }

    ////////////////////////////////////////////////////////////////////////
    template<typename X>
    void ScalarIntTransform<X>::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
        DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS);
    }
    BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);


    template<typename X>
    template <typename OpType>
    void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    }

    template<typename X>
    void ScalarIntTransform<X>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    }

    template<typename X>
    void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {

    }

    template<typename X>
    void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {

    }

    template<typename X>
    template<typename OpType>
    void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {

    }


    template<typename X>
    template<typename OpType>
    void ScalarIntTransform<X>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {

    }
}
}
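
For orientation, a minimal host-side sketch of how the shaped CUDA entry point above could be driven. Only the executeCudaShaped signature, the launchDims.x/y/z usage (grid, block, shared memory) and the SCALAR_INT_OPS numbering come from this diff; the stream setup and the device buffers dX, dXShapeInfo, dZ, dZShapeInfo and dScalar are hypothetical placeholders, assumed to be already allocated and copied to the device.

    // Hedged sketch, not part of the commit.
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    dim3 launchDims(512, 256, 1024);   // grid size, block size, shared memory - assumed values

    functions::scalar::ScalarIntTransform<int>::executeCudaShaped(
            launchDims, &stream,
            0,                    // opNum 0 == ShiftLeft in SCALAR_INT_OPS (see legacy_ops.h below)
            dX, dXShapeInfo,      // device input buffer + shape info
            dZ, dZShapeInfo,      // device output buffer + shape info
            dScalar,              // device scalar holding the shift amount
            nullptr);             // shift ops take no extra params

    cudaStreamSynchronize(stream);
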
@ -29,6 +29,16 @@
        (4, aggregateOps::CBOW) ,\
        (5, aggregateOps::GEMM)

#define BROADCAST_INT_OPS \
        (0, ShiftLeft), \
        (1, ShiftRight), \
        (2, CyclicShiftLeft), \
        (3, CyclicShiftRight), \
        (4, IntAnd), \
        (5, IntOr), \
        (6, IntXor)


#define BROADCAST_BOOL_OPS \
        (0, EqualTo),\
        (1, GreaterThan),\

@ -171,6 +181,14 @@
        (0, SummaryStatsVariance), \
        (1, SummaryStatsStandardDeviation)

#define SCALAR_INT_OPS \
        (0, ShiftLeft), \
        (1, ShiftRight), \
        (2, CyclicShiftLeft), \
        (3, CyclicShiftRight), \
        (4, IntAnd), \
        (5, IntOr), \
        (6, IntXor)

#define SCALAR_BOOL_OPS \
        (0, EqualTo),\

@ -300,6 +318,15 @@
        (13, ExponentialDistribution),\
        (14, ExponentialDistributionInv)

#define PAIRWISE_INT_OPS \
        (0, ShiftLeft), \
        (1, ShiftRight), \
        (2, CyclicShiftLeft), \
        (3, CyclicShiftRight), \
        (4, IntAnd), \
        (5, IntOr), \
        (6, IntXor)

#define PAIRWISE_BOOL_OPS \
        (0, EqualTo),\
        (1, GreaterThan),\
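
Each list pairs an op number with a simdOps class; BUILD_ENUMERATION turns the pairs into enum members in op_enums.h (further down), and DISPATCH_BY_OPNUM_T uses the same numbering to route a runtime opNum to the right template instantiation. Roughly - an illustration of the resulting enum, not the literal macro output - the pairwise list expands to:

    namespace nd4j {
        namespace pairwise {
            enum IntOps {
                ShiftLeft = 0,
                ShiftRight = 1,
                CyclicShiftLeft = 2,
                CyclicShiftRight = 3,
                IntAnd = 4,
                IntOr = 5,
                IntXor = 6
            };
        }
    }
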
@ -0,0 +1,119 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

/*
 * pairwise_int.h
 *
 * Created on: Dec 28, 2015
 * Author: agibsonccc
 */

#ifndef PAIRWISE_INT_H_
#define PAIRWISE_INT_H_
#ifdef _OPENMP
#include <omp.h>
#endif
#include <templatemath.h>
#include <helpers/shape.h>
#include <pairwise_util.h>
#include <dll.h>
#include <stdio.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include <helpers/DebugHelper.h>

#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#ifndef _OPENMP
#define omp_get_thread_num() 0
#define omp_get_max_threads() 1
#endif


#include "legacy_ops.h"

using namespace simdOps;

namespace functions {
    namespace pairwise_transforms {

        /**
         * Transforms involving 2 arrays
         */
        template<typename X>
        class PairWiseIntTransform {
        public:

#ifdef __CUDACC__

            template <typename OpType>
            static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams);

            static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams);


#endif
        public:

            static void exec(
                    const int opNum,
                    void *dx,
                    Nd4jLong *xShapeBuffer,
                    void *y,
                    Nd4jLong *yShapeBuffer,
                    void *result,
                    Nd4jLong *resultShapeBuffer,
                    void *extraParams);

            static void exec(
                    const int opNum,
                    void *dx,
                    Nd4jLong xStride,
                    void *y,
                    Nd4jLong yStride,
                    void *result,
                    Nd4jLong resultStride,
                    void *extraParams,
                    Nd4jLong n);


            template<typename OpType>
            static void exec(
                    void *vx,
                    Nd4jLong* xShapeBuffer,
                    void *vy,
                    Nd4jLong* yShapeBuffer,
                    void *vresult,
                    Nd4jLong* resultShapeBuffer,
                    void *vextraParams);

            template<typename OpType>
            static void exec(void *vx,
                             Nd4jLong xStride,
                             void *vy,
                             Nd4jLong yStride,
                             void *vresult,
                             Nd4jLong resultStride,
                             void *vextraParams,
                             const Nd4jLong n);
        };
    }
}

#endif /* PAIRWISE_INT_H_ */
@ -0,0 +1,142 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

/*
 * scalar_int.h
 *
 * Created on: Dec 28, 2015
 * Author: agibsonccc
 */

#ifndef SCALAR_INT_H_
#define SCALAR_INT_H_
#include <dll.h>

#ifdef __JNI__
#include <jni.h>
#endif
#include <templatemath.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include "helpers/logger.h"
#include <OmpLaunchHelper.h>
#include <helpers/DebugHelper.h>

#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#include <types/float16.h>
#endif

#include "legacy_ops.h"

namespace functions {
    namespace scalar {
        /**
         * Apply a scalar
         * operation to an array
         */
        template<typename X>
        class ScalarIntTransform {

        public:

#ifdef __CUDACC__

            template<typename OpType>
            __device__
            static void transformCuda(void* scalar, void *vy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *resultShapeInfo, int *allocationBuffer);

            template<typename OpType>
            __device__
            static void transformCuda(Nd4jLong n, void* vx, void *vy, Nd4jLong yEWS, void *vparams, void *vz, Nd4jLong zEWS, int *allocationBuffer);

            template<typename OpType>
            __device__
            static void transformCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vscalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

            template <typename OpType>
            __host__
            static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

            template <typename OpType>
            __host__
            static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void* vscalar, void *vextraParams, int *allocPointer);

            __host__
            static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void* scalar, void *extraParams);

            __host__
            static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);


/*
#include "cuda/scalar_temp.cu"
*/
#endif
            template <typename OpType>
            static void transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

            static void transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

            static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);

            static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n);




            /*
             * ScalarOp along dimension
             */


            /**
             * CPU implementation of scalar operation
             * @param x the input
             * @param xStride the stride for the input
             * @param result the result buffer
             * @param resultStride the stride for the result
             * @param scalar the scalar to apply
             * @param extraParams the extra parameters where
             * necessary
             * @param n the number of elements to loop over
             */

            template<typename OpType>
            static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);


            /**
             * CPU implementation of scalar operation
             * @param x the input
             * @param xStride the stride for the input
             * @param result the result buffer
             * @param resultStride the stride for the result
             * @param scalar the scalar to apply
             * @param extraParams the extra parameters where
             * necessary
             * @param n the number of elements to loop over
             */

            template<typename OpType>
            static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n);
        };
    }
}


#endif /* SCALAR_INT_H_ */
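
A hedged sketch of calling the strided CPU overload declared above, assuming the CPU implementation (which lives outside this diff) is linked in; the buffers and values are made up, the op number comes from SCALAR_INT_OPS:

    // Sketch only: left-shift five ints by 4 bits via the strided overload.
    int x[5]      = {32, 32, 32, 32, 32};
    int result[5] = {0, 0, 0, 0, 0};
    int scalar    = 4;                     // shift amount

    functions::scalar::ScalarIntTransform<int>::transform(
            0,                             // opNum 0 == ShiftLeft
            x, 1,                          // input buffer and its element-wise stride
            result, 1,                     // output buffer and its stride
            &scalar,
            nullptr,                       // shift ops use no extra params
            5);                            // number of elements

    // result now holds {512, 512, 512, 512, 512} - the same 32 << 4 the tests below check.
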
@ -63,6 +63,10 @@ namespace nd4j {
        enum BoolOps {
            BUILD_ENUMERATION(PAIRWISE_BOOL_OPS)
        };

        enum IntOps {
            BUILD_ENUMERATION(PAIRWISE_INT_OPS)
        };
    }

    namespace scalar {

@ -73,6 +77,10 @@ namespace nd4j {
        enum BoolOps {
            BUILD_ENUMERATION(SCALAR_BOOL_OPS)
        };

        enum IntOps {
            BUILD_ENUMERATION(SCALAR_INT_OPS)
        };
    }

    namespace reduce {

@ -113,6 +121,10 @@ namespace nd4j {
        enum BoolOps {
            BUILD_ENUMERATION(BROADCAST_BOOL_OPS)
        };

        enum IntOps {
            BUILD_ENUMERATION(BROADCAST_INT_OPS)
        };
    }

    namespace variance {
@ -0,0 +1,49 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H
#define DEV_TESTS_BROADCASTINTOPSTUPLE_H

#include <op_enums.h>

namespace nd4j {
    class BroadcastIntOpsTuple {
    private:

    public:
        nd4j::scalar::IntOps s;
        nd4j::pairwise::IntOps p;
        nd4j::broadcast::IntOps b;

        BroadcastIntOpsTuple() = default;
        ~BroadcastIntOpsTuple() = default;

        BroadcastIntOpsTuple(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) {
            s = scalar;
            p = pairwise;
            b = broadcast;
        }

        static BroadcastIntOpsTuple custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast);
    };
}


#endif //DEV_TESTS_BROADCASTINTOPSTUPLE_H
@ -27,22 +27,14 @@

namespace nd4j {
    namespace ops {
        CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
        BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing");
            BROADCAST_CHECK_EMPTY(x,y,z);

            uint32_t shift = 0;
            if (block.width() > 1) {
                shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
            } else if (block.numI() > 0) {
                shift = INT_ARG(0);
            };

            helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift);

            REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type")
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false);

            return Status::OK();
        }
@ -27,22 +27,14 @@

namespace nd4j {
    namespace ops {
        CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
        BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");
            BROADCAST_CHECK_EMPTY(x,y,z);

            uint32_t shift = 0;
            if (block.width() > 1) {
                shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
            } else if (block.numI() > 0) {
                shift = INT_ARG(0);
            };

            helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);

            REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false);

            return Status::OK();
        }
@ -27,22 +27,14 @@

namespace nd4j {
    namespace ops {
        CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
        BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing");
            BROADCAST_CHECK_EMPTY(x,y,z);

            uint32_t shift = 0;
            if (block.width() > 1) {
                shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
            } else if (block.numI() > 0) {
                shift = INT_ARG(0);
            };

            REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type")

            helpers::rshift_bits(block.launchContext(), *input, *output, shift);
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false);

            return Status::OK();
        }
@ -27,22 +27,14 @@

namespace nd4j {
    namespace ops {
        CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
            auto input = INPUT_VARIABLE(0);
            auto output = OUTPUT_VARIABLE(0);
        BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");
            BROADCAST_CHECK_EMPTY(x,y,z);

            uint32_t shift = 0;
            if (block.width() > 1) {
                shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
            } else if (block.numI() > 0) {
                shift = INT_ARG(0);
            };

            REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type")

            helpers::shift_bits(block.launchContext(), *input, *output, shift);
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false);

            return Status::OK();
        }
@ -45,7 +45,7 @@ namespace nd4j {
         * @tparam T
         */
        #if NOT_EXCLUDED(OP_shift_bits)
        DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
        DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0);
        #endif

        /**

@ -56,7 +56,7 @@ namespace nd4j {
         * @tparam T
         */
        #if NOT_EXCLUDED(OP_rshift_bits)
        DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2);
        DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0);
        #endif

        /**

@ -67,7 +67,7 @@ namespace nd4j {
         * @tparam T
         */
        #if NOT_EXCLUDED(OP_cyclic_shift_bits)
        DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
        DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0);
        #endif

        /**

@ -78,7 +78,7 @@ namespace nd4j {
         * @tparam T
         */
        #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
        DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
        DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
        #endif

        /**
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
#include <ops/BroadcastIntOpsTuple.h>
|
||||
|
||||
namespace nd4j {
|
||||
BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) {
|
||||
BroadcastIntOpsTuple t(scalar, pairwise, broadcast);
|
||||
return t;
|
||||
}
|
||||
}
|
|
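
The tuple just bundles the scalar, pairwise and broadcast variants of one integer op so NDArray::applyTrueBroadcast can pick whichever execution path fits the operand shapes. The rewritten shift ops above use it exactly this way; as a standalone sketch (x, y, z assumed to be integer NDArray instances):

    auto t = nd4j::BroadcastIntOpsTuple::custom(nd4j::scalar::ShiftLeft,
                                                nd4j::pairwise::ShiftLeft,
                                                nd4j::broadcast::ShiftLeft);
    x.applyTrueBroadcast(t, &y, &z, false);   // z = x << y, broadcasting y as needed
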
@ -660,6 +660,98 @@ namespace simdOps {
        }
    };

    template <typename X>
    class IntOr {
    public:

        op_def static X op(X d1, X d2) {
            return d2 | d1;
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class IntAnd {
    public:

        op_def static X op(X d1, X d2) {
            return d2 & d1;
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class IntXor {
    public:

        op_def static X op(X d1, X d2) {
            return d2 ^ d1;
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class ShiftLeft {
    public:

        op_def static X op(X d1, X d2) {
            return d1 << d2;
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class ShiftRight {
    public:

        op_def static X op(X d1, X d2) {
            return d1 >> d2;
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class CyclicShiftLeft {
    public:

        op_def static X op(X d1, X d2) {
            return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2);
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };

    template <typename X>
    class CyclicShiftRight {
    public:

        op_def static X op(X d1, X d2) {
            return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2);
        }

        op_def static X op(X d1, X d2, X *params) {
            return op(d1, d2);
        }
    };


    template <typename X, typename Z>
    class Or {
    public:
|
|||
|
||||
TEST_F(DeclarableOpsTests13, shift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>(4);
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::shift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
@ -640,12 +641,13 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) {
|
|||
|
||||
TEST_F(DeclarableOpsTests13, rshift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>(4);
|
||||
auto e = x.ulike();
|
||||
x.assign(512);
|
||||
e.assign(32);
|
||||
|
||||
nd4j::ops::rshift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
@ -657,12 +659,13 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) {
|
|||
|
||||
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>(4);
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::cyclic_shift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
@ -674,12 +677,107 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
|
|||
|
||||
TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>(4);
|
||||
auto e = x.ulike();
|
||||
x.assign(512);
|
||||
e.assign(32);
|
||||
|
||||
nd4j::ops::cyclic_rshift_bits op;
|
||||
auto result = op.execute({&x}, {}, {4});
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests13, shift_bits_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
y.assign(4);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::shift_bits op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests13, rshift_bits_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(512);
|
||||
y.assign(4);
|
||||
e.assign(32);
|
||||
|
||||
nd4j::ops::rshift_bits op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
y.assign(4);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::cyclic_shift_bits op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5});
|
||||
auto y = NDArrayFactory::create<int>('c', {5});
|
||||
auto e = x.ulike();
|
||||
x.assign(512);
|
||||
y.assign(4);
|
||||
e.assign(32);
|
||||
|
||||
nd4j::ops::cyclic_rshift_bits op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
ASSERT_EQ(e, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
TEST_F(DeclarableOpsTests13, shift_bits_3) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5, 5});
|
||||
auto y = NDArrayFactory::create<int>('c', {1, 5});
|
||||
auto e = x.ulike();
|
||||
x.assign(32);
|
||||
y.assign(4);
|
||||
e.assign(512);
|
||||
|
||||
nd4j::ops::shift_bits op;
|
||||
auto result = op.execute({&x, &y}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
|
|
@ -1234,19 +1234,19 @@ public class DifferentialFunctionFactory {
        return new Xor(sameDiff(), ix, iy).outputVariable();
    }

    public SDVariable shift(SDVariable ix, int shift) {
    public SDVariable shift(SDVariable ix, SDVariable shift) {
        return new ShiftBits(sameDiff(), ix, shift).outputVariable();
    }

    public SDVariable rshift(SDVariable ix, int shift) {
    public SDVariable rshift(SDVariable ix, SDVariable shift) {
        return new RShiftBits(sameDiff(), ix, shift).outputVariable();
    }

    public SDVariable rotl(SDVariable ix, int shift) {
    public SDVariable rotl(SDVariable ix, SDVariable shift) {
        return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable();
    }

    public SDVariable rotr(SDVariable ix, int shift) {
    public SDVariable rotr(SDVariable ix, SDVariable shift) {
        return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
    }

@ -2428,7 +2428,7 @@ public class SDMath extends SDOps {
     * @param x Input 1
     * @return Output SDVariable with shifted bits
     */
    public SDVariable bitShift(String name, SDVariable x, int shift) {
    public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
        validateInteger("shift_bits", x);
        SDVariable result = f().shift(x, shift);
        return updateVariableNameAndReference(result, name);

@ -2441,7 +2441,7 @@ public class SDMath extends SDOps {
     * @param x Input 1
     * @return Output SDVariable with shifted bits
     */
    public SDVariable bitShiftRight(String name, SDVariable x, int shift) {
    public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
        validateInteger("rshift_bits", x);
        SDVariable result = f().rshift(x, shift);
        return updateVariableNameAndReference(result, name);

@ -2454,7 +2454,7 @@ public class SDMath extends SDOps {
     * @param x Input 1
     * @return Output SDVariable with shifted bits
     */
    public SDVariable bitRotl(String name, SDVariable x, int shift) {
    public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
        validateInteger("cyclic_shift_bits", x);
        SDVariable result = f().rotl(x, shift);
        return updateVariableNameAndReference(result, name);

@ -2467,7 +2467,7 @@ public class SDMath extends SDOps {
     * @param x Input 1
     * @return Output SDVariable with shifted bits
     */
    public SDVariable bitRotr(String name, SDVariable x, int shift) {
    public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
        validateInteger("cyclic_rshift_bits", x);
        SDVariable result = f().rotr(x, shift);
        return updateVariableNameAndReference(result, name);
@ -34,18 +34,16 @@ import java.util.List;
 */
public class CyclicRShiftBits extends BaseDynamicTransformOp {

    public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
        super(sameDiff, new SDVariable[] {x} ,false);
        this.addIArgument(shift);
    public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) {
        super(sameDiff, new SDVariable[] {x, shift} ,false);
    }

    public CyclicRShiftBits(INDArray input, int shift, INDArray output) {
        super(new INDArray[]{input}, new INDArray[]{output});
        this.addIArgument(shift);
    public CyclicRShiftBits(INDArray input, INDArray shift, INDArray output) {
        super(new INDArray[]{input, shift}, new INDArray[]{output});
    }

    public CyclicRShiftBits(INDArray input, int shift) {
        this(input, shift,null);
    public CyclicRShiftBits(INDArray input, INDArray shift) {
        this(input, shift,input.ulike());
    }

    public CyclicRShiftBits() {}
@ -34,18 +34,16 @@ import java.util.List;
 */
public class CyclicShiftBits extends BaseDynamicTransformOp {

    public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
        super(sameDiff, new SDVariable[] {x} ,false);
        this.addIArgument(shift);
    public CyclicShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) {
        super(sameDiff, new SDVariable[] {x, shift} ,false);
    }

    public CyclicShiftBits(INDArray input, int shift, INDArray output) {
        super(new INDArray[]{input}, new INDArray[]{output});
        this.addIArgument(shift);
    public CyclicShiftBits(INDArray input, INDArray shift, INDArray output) {
        super(new INDArray[]{input, shift}, new INDArray[]{output});
    }

    public CyclicShiftBits(INDArray input, int shift) {
        this(input, shift,null);
    public CyclicShiftBits(INDArray input, INDArray shift) {
        this(input, shift,input.ulike());
    }

    public CyclicShiftBits() {}
@ -34,18 +34,16 @@ import java.util.List;
 */
public class RShiftBits extends BaseDynamicTransformOp {

    public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
        super(sameDiff, new SDVariable[] {x} ,false);
        this.addIArgument(shift);
    public RShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) {
        super(sameDiff, new SDVariable[] {x, y} ,false);
    }

    public RShiftBits(INDArray input, int shift, INDArray output) {
        super(new INDArray[]{input}, new INDArray[]{output});
        this.addIArgument(shift);
    public RShiftBits(INDArray input, INDArray shift, INDArray output) {
        super(new INDArray[]{input, shift}, new INDArray[]{output});
    }

    public RShiftBits(INDArray input, int shift) {
        this(input, shift,null);
    public RShiftBits(INDArray input, INDArray shift) {
        this(input, shift,input.ulike());
    }

    public RShiftBits() {}
@ -34,18 +34,16 @@ import java.util.List;
 */
public class ShiftBits extends BaseDynamicTransformOp {

    public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
        super(sameDiff, new SDVariable[] {x} ,false);
        this.addIArgument(shift);
    public ShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) {
        super(sameDiff, new SDVariable[] {x, y} ,false);
    }

    public ShiftBits(INDArray input, int shift, INDArray output) {
        super(new INDArray[]{input}, new INDArray[]{output});
        this.addIArgument(shift);
    public ShiftBits(INDArray x, INDArray y, INDArray output) {
        super(new INDArray[]{x, y}, new INDArray[]{output});
    }

    public ShiftBits(INDArray input, int shift) {
        this(input, shift,null);
    public ShiftBits(INDArray x, INDArray y) {
        this(x, y,x.ulike());
    }

    public ShiftBits() {}
@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <op_enums.h>
// #include <ops/BroadcastOpsTuple.h>
// #include <ops/BroadcastBoolOpsTuple.h>
// #include <ops/BroadcastIntOpsTuple.h>
// #include <array/ExtraArguments.h>
// #include <Status.h>
// #include <ShapeDescriptor.h>

@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <op_enums.h>
// #include <ops/BroadcastOpsTuple.h>
// #include <ops/BroadcastBoolOpsTuple.h>
// #include <ops/BroadcastIntOpsTuple.h>
// #include <array/ExtraArguments.h>
// #include <Status.h>
// #include <ShapeDescriptor.h>

@ -16496,15 +16497,18 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

/**
 * creates identity 2D matrix or batch of identical 2D identity matrices
 *
 *
 * Input array:
 * provide some array - in any case operation simply neglects it
 *
 *
 * Input float argument (if passed):
 * TArgs[0] - type of elements of output array, default value is 5 (float)
 *
 * Input integer arguments:
 * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
 * IArgs[1] - the number of rows in output inner-most 2D identity matrix
 * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows
 * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
 * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
 */
// #if NOT_EXCLUDED(OP_eye)
@Namespace("nd4j::ops") public static class eye extends DeclarableCustomOp {

@ -16598,10 +16602,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

/**
 * clip a list of given tensors with given average norm when needed
 *
 *
 * Input:
 * a list of tensors (at least one)
 *
 *
 * Input floating point argument:
 * clip_norm - a value that used as threshold value and norm to be used
 *

@ -16749,12 +16753,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

/**
 * returns histogram (as 1D array) with fixed bins width
 *
 *
 * Input arrays:
 * - input array with elements to be binned into output histogram
 * - input array with elements to be binned into output histogram
 * - range array with first element being bottom limit and second element being top limit of histogram,
 *   please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1]
 *
 *
 * Input integer arguments:
 * nbins (optional) - number of histogram bins, default value is 100
 */

@ -21822,7 +21826,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 * \tparam T
 */
// #if NOT_EXCLUDED(OP_shift_bits)
@Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class shift_bits extends BroadcastableOp {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public shift_bits(Pointer p) { super(p); }

@ -21835,7 +21839,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

    public shift_bits() { super((Pointer)null); allocate(); }
    private native void allocate();
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

@ -21847,7 +21850,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 * \tparam T
 */
// #if NOT_EXCLUDED(OP_rshift_bits)
@Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class rshift_bits extends BroadcastableOp {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public rshift_bits(Pointer p) { super(p); }

@ -21860,7 +21863,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

    public rshift_bits() { super((Pointer)null); allocate(); }
    private native void allocate();
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

@ -21872,7 +21874,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 * \tparam T
 */
// #if NOT_EXCLUDED(OP_cyclic_shift_bits)
@Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class cyclic_shift_bits extends BroadcastableOp {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public cyclic_shift_bits(Pointer p) { super(p); }

@ -21885,7 +21887,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

    public cyclic_shift_bits() { super((Pointer)null); allocate(); }
    private native void allocate();
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

@ -21897,7 +21898,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 * \tparam T
 */
// #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
@Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class cyclic_rshift_bits extends BroadcastableOp {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public cyclic_rshift_bits(Pointer p) { super(p); }

@ -21910,6 +21911,30 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

    public cyclic_rshift_bits() { super((Pointer)null); allocate(); }
    private native void allocate();
}
// #endif

/**
 * This operation returns hamming distance based on bits
 *
 * PLEASE NOTE: This operation is applicable only to integer data types
 *
 * \tparam T
 */
// #if NOT_EXCLUDED(OP_bits_hamming_distance)
@Namespace("nd4j::ops") public static class bits_hamming_distance extends DeclarableCustomOp {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public bits_hamming_distance(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    @Override public bits_hamming_distance position(long position) {
        return (bits_hamming_distance)super.position(position);
    }

    public bits_hamming_distance() { super((Pointer)null); allocate(); }
    private native void allocate();
    public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
@ -90,8 +90,13 @@ class SDVariableWrapper {
  def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other)
  def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other)

  def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x)
  def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x)
  def <<(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, other)
  def >>(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.bitShiftRight(null, thisVariable, other)
  def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.bitShift(null, thisVariable, sameDiff.constant(x))
  def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable =
    sameDiff.math.bitShiftRight(null, thisVariable, sameDiff.constant(x))

  // Overloads for numeric arguments
  // Float

@ -188,4 +188,16 @@ class MathTest extends FlatSpec with Matchers {
    val w3 = w1 >> 2
    w3.eval.toIntVector.head shouldBe 4
  }

  "SameDiff" should "provide shifting operations with SDVariable argument" in {
    implicit val sd = SameDiff.create()
    val w1 = sd.constant(16)
    val two = sd.constant(2)

    val w2 = w1 << two
    w2.eval.toIntVector.head shouldBe 64

    val w3 = w1 >> two
    w3.eval.toIntVector.head shouldBe 4
  }
}