[WIP] Int broadcastables (#195)

* Removed invalid resource and fixed tests

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* legacy scalar/pairwise/broadcast int ops

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray int broadcastables

Signed-off-by: raver119 <raver119@gmail.com>

* few more bitwise tests

Signed-off-by: raver119 <raver119@gmail.com>

* java side update

Signed-off-by: raver119 <raver119@gmail.com>

* Argument type changed for shift ops

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>

* legacy scalar/pairwise/broadcast int ops

Signed-off-by: raver119 <raver119@gmail.com>

* NDArray int broadcastables

Signed-off-by: raver119 <raver119@gmail.com>

* few more bitwise tests

Signed-off-by: raver119 <raver119@gmail.com>

* java side update

Signed-off-by: raver119 <raver119@gmail.com>

* Argument type changed for shift ops

Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
master
raver119 2019-08-30 10:12:40 +03:00 committed by GitHub
parent 130c9aa536
commit 1003428a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 3237 additions and 126 deletions

View File

@ -103,7 +103,7 @@ public class FastTextTest extends BaseDL4JTest {
}
@Test
public void testPredict() throws IOException {
public void testPredict() {
String text = "I like soccer";
FastText fastText = new FastText(supModelFile);
@ -118,7 +118,7 @@ public class FastTextTest extends BaseDL4JTest {
}
@Test
public void testPredictProbability() throws IOException {
public void testPredictProbability() {
String text = "I like soccer";
FastText fastText = new FastText(supModelFile);

View File

@ -34,6 +34,7 @@
#include <op_enums.h>
#include <ops/BroadcastOpsTuple.h>
#include <ops/BroadcastBoolOpsTuple.h>
#include <ops/BroadcastIntOpsTuple.h>
#include <array/ExtraArguments.h>
#include <Status.h>
#include <ShapeDescriptor.h>
@ -659,6 +660,8 @@ namespace nd4j {
void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const;
void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const;
/**
* apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this)
* tad - array to broadcast
@ -672,6 +675,9 @@ namespace nd4j {
void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int> &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr);
void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int> &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr);
/**
* apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting
* other - input array
@ -692,6 +698,9 @@ namespace nd4j {
void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const;
/**
* apply a scalar operation to an array
* scalar - input scalar
@ -704,6 +713,9 @@ namespace nd4j {
template <typename T>
void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
template <typename T>
void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
/**
* apply a scalar operation to an array
* scalar - input array which is simple scalar
@ -714,6 +726,7 @@ namespace nd4j {
void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const;
#if defined(__CUDABLAS__) //&& defined(BUILD_TESTS)
template <typename Lambda>

View File

@ -2663,6 +2663,88 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
delete pTarget;
}
//////////////////////////////////////////////////////////////////////////
void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
    // Apply an integer broadcastable operation between *this and *other, writing the
    // result into *target. Scalar operands are routed to the pairwise/scalar kernels;
    // otherwise the smaller-ranked operand is tiled as needed and the TAD-based
    // applyBroadcast path is used.
    //
    // op               - tuple holding the pairwise (p), scalar (s) and broadcast (b) variants of the op
    // other            - second operand (must have the same dtype as *this)
    // target           - output array; must already have the broadcast shape and this array's dtype
    // checkTargetShape - when true, validates target's shape/dtype against the evaluated broadcast shape
    // throws std::runtime_error / std::invalid_argument on validation failures
    if (isS())
        throw std::runtime_error("NDArray::applyTrueBroadcast int: you can't use this method on String array!");    // fixed: message previously said "bool" (copy-paste from the Bool variant)

    if(target == nullptr || other == nullptr)
        throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !");

    if (isEmpty() || other->isEmpty())
        return;

    NDArray::prepareSpecialUse({target}, {this, other});

    if (isScalar()) {
        // *this is a scalar: materialize it into target's shape, then run the pairwise kernel
        NDArray temp(target->_shapeInfo, dataType(), false, getContext());
        temp.assign(this);
        temp.applyPairwiseTransform(op.p, other, target, extraArgs);
        return;
    }
    if (other->isScalar()) {
        // other is a scalar: the scalar kernel broadcasts it directly
        this->applyScalarArr(op.s, other, target, extraArgs);
        return;
    }

    // pick the larger-ranked operand as "max"; the other is tiled/broadcast over it
    const NDArray* min(other);
    const NDArray* max(this);
    if(this->rankOf() < other->rankOf()) {
        max = other;
        min = this;
    }

    if(checkTargetShape) {
        Nd4jLong* newShapeInfo = nullptr;
        if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace()))          // the rank of target array must be equal to max->rankOf()
            throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
        if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType())
            throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !");
        if(dataType() != other->dataType())
            throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
    }

    // intermediate buffer when target's dtype differs from the broadcast dtype
    NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext());

    // check whether max array has to be tiled up to target's shape
    if(!max->isSameShape(target)) {
        // evaluate repeating dimensions for tile operation (shapeInfo[1..rank] are the dims)
        std::vector<Nd4jLong> repeatMax(max->rankOf());
        for(int i = 1; i <= max->rankOf(); ++i)
            repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]);
        max->tile(repeatMax, *pTarget);
    }
    else
        pTarget->assign(max);

    // check whether min array has to be tiled; product == 1 means no tiling needed
    std::vector<Nd4jLong> repeatMin(min->rankOf());
    int product = 1;
    for(int i = min->rankOf(); i >=1 ; --i) {
        repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]);
        product *= repeatMin[i-1];
    }

    auto pMin = const_cast<NDArray *>(min);
    if(product != 1 )
        pMin = new NDArray(min->tile(repeatMin));

    std::vector<int> sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin);

    if(max == this)
        pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs);
    else
        pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs);

    // free temporaries allocated above
    if(pMin != min)
        delete pMin;
    if(pTarget != target)
        delete pTarget;
}
//////////////////////////////////////////////////////////////////////////
NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const {
if (isEmpty() || other.isEmpty()) {
@ -2801,6 +2883,67 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
registerSpecialUse({result}, {this, other});
}
//////////////////////////////////////////////////////////////////////////
// Applies an integer broadcast op between *this and *other along the given dimensions.
// Output goes to *target, or in-place into *this when target is nullptr. Same-length,
// same-rank operands short-circuit to the pairwise kernel; otherwise the larger operand
// ("max") is broadcast over via TAD packs, using the inverse kernel when *this is the
// smaller operand.
// throws std::runtime_error / std::invalid_argument on type/shape validation failures.
void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) {
if (!isZ())
throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
if(isEmpty() || other->isEmpty()) {
// empty in, empty out: only validate that target agrees
if(!target->isEmpty())
throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !");
return;
}
if (dimensions.empty())
return;
// in-place when no explicit target was supplied
auto result = target == nullptr ? this : target;
// fast path: equal length and rank degenerates to a pairwise transform
if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) {
NDArray::prepareSpecialUse({result}, {this, other});
NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
NDArray::registerSpecialUse({result}, {this, other});
return;
}
// pick the larger operand (by length, then rank) as the one being broadcast over
NDArray *min(nullptr), *max(nullptr);
if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) {
max = this;
min = const_cast<NDArray*>(other);
}
else {
max = const_cast<NDArray*>(other);
min = this;
}
if(result->dataType() != dataType())
throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!");
if(!result->isSameShape(max))
throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !");
if(_dataType != other->_dataType)
throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");
// dimensions must be sorted for TAD computation
std::vector<int> copy(dimensions);
if (dimensions.size() > 1)
std::sort(copy.begin(), copy.end());
// each TAD of max must match min's length exactly
Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size());
if (tadLength != min->lengthOf())
throw std::runtime_error("Tad length mismatch");
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy);
// TODO: eventually we want separate tads here
NDArray::prepareSpecialUse({result}, {this, other});
// forward kernel when *this is max; inverse kernel when *this is min
if(max == this)
NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else
NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
registerSpecialUse({result}, {this, other});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) {
std::vector<int> vec(dimensions);
@ -3043,6 +3186,22 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *
NDArray::registerSpecialUse({target}, {this, other});
}
////////////////////////////////////////////////////////////////////////
void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{
    // Element-wise integer transform: target[i] = op(this[i], other[i]).
    // Requires matching lengths, an integer-typed target, and identical dtypes
    // for *this and *other.
    if (isS())
        throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!");
    if (other->lengthOf() != target->lengthOf())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched");
    if (!target->isZ())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have integer type");    // fixed: message previously claimed "bool type" though the check is isZ()
    if (dataType() != other->dataType())
        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !");

    NDArray::prepareSpecialUse({target}, {this, other});
    NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr);
    NDArray::registerSpecialUse({target}, {this, other});
}
//////////////////////////////////////////////////////////////////////////
void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) {
applyPairwiseTransform(op, &other, this, extraParams);
@ -3585,6 +3744,45 @@ template void NDArray::applyScalar<int8_t>(nd4j::scalar::BoolOps op, const int8_
template void NDArray::applyScalar<uint8_t>(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bool>(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
//////////////////////////////////////////////////////////////////////////
void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const {
    // Applies an integer scalar op: target[i] = op(this[i], scalar).
    // target and scalar must both have the same dtype as *this.
    if (isS())
        throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!");
    if (target == nullptr || target->dataType() != this->dataType())
        throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has a different type than this array!");    // fixed: message previously said "has not bool type" (copy-paste from the Bool variant)
    if (dataType() != scalar->dataType()) {
        nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType());
        throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!");
    }

    NDArray::prepareSpecialUse({target}, {this, scalar});
    NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr);
    NDArray::registerSpecialUse({target}, {this, scalar});
}
////////////////////////////////////////////////////////////////////////
// Convenience overload: wraps a primitive scalar into a temporary NDArray of this
// array's dtype and delegates to applyScalarArr.
template <typename T>
void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const {

// the scalar is created with this array's dtype, so the dtype check in
// applyScalarArr always passes for the primitive-typed overloads below
NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext());
applyScalarArr(op, &scalarArr, target, extraParams);
}

// guard specialization: calling applyScalar with an NDArray* is a programming error —
// callers must use applyScalarArr instead
template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar<NDArray*> method: do not use me!");}
// explicit instantiations for every primitive scalar type callers may pass
template void NDArray::applyScalar<double>(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<float>(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<float16>(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bfloat16>(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<Nd4jLong>(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int>(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int16_t>(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<int8_t>(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<uint8_t>(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const;
template void NDArray::applyScalar<bool>(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const;
////////////////////////////////////////////////////////////////////////
void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector<int>& dimensions, const ExtraArguments *extraParams) const {
if (isS())

View File

@ -189,6 +189,16 @@ static void execScalarBool(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true);
static void execScalarInt(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism = true);
static void execScalar(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
@ -215,6 +225,20 @@ static void execScalarBool(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
static void execScalarInt(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
/**
*
* @param opNum
@ -275,6 +299,30 @@ static void execScalarBool(nd4j::LaunchContext *lc,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static void execBroadcastInt(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
static void execInverseBroadcastInt(nd4j::LaunchContext *lc,
int opNum,
void *x, Nd4jLong *xShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *y, Nd4jLong *yShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
/**
*
@ -308,6 +356,16 @@ static void execScalarBool(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams);
static void execPairwiseIntTransform(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hY, Nd4jLong *hYShapeInfo,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams);
/**
*
* @param opNum

View File

@ -24,6 +24,10 @@
#include <broadcasting_bool.h>
#include <scalar_bool.h>
#include <pairwise_int.h>
#include <broadcasting_int.h>
#include <scalar_int.h>
#include <loops/transform_float.h>
#include <loops/transform_bool.h>
#include <loops/transform_any.h>
@ -242,6 +246,66 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
    // CPU dispatch for integer broadcast ops: validates that x, y and z share one
    // integer data type, then expands the BroadcastInt kernel for that type.
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    const auto typeX = nd4j::ArrayOptions::dataType(hXShapeInfo);
    const auto typeY = nd4j::ArrayOptions::dataType(hYShapeInfo);
    const auto typeZ = nd4j::ArrayOptions::dataType(hZShapeInfo);

    // all three operands must carry the same data type
    const bool sameTypes = (typeX == typeY) && (typeX == typeZ);
    if (!sameTypes)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", typeZ, typeX, typeY);

    // and that type must be integral
    if (!nd4j::DataTypeUtils::isZ(typeZ))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", typeZ);

    BUILD_SINGLE_SELECTOR(typeX, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}
void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                                  int opNum,
                                                  void *hX, Nd4jLong *hXShapeInfo,
                                                  void *dX, Nd4jLong *dXShapeInfo,
                                                  void *hY, Nd4jLong *hYShapeInfo,
                                                  void *dY, Nd4jLong *dYShapeInfo,
                                                  void *hZ, Nd4jLong *hZShapeInfo,
                                                  void *dZ, Nd4jLong *dZShapeInfo,
                                                  int *dimension, int dimensionLength,
                                                  Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                                  Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
    // CPU dispatch for the inverse integer broadcast (y broadcast over x): same
    // type validation as execBroadcastInt, then runs the execInverse kernel.
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    // all three operands must share a single integer data type
    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);    // fixed: exception previously named execPairwiseIntTransform

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt requires integer data type", zType);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}
////////////////////////////////////////////////////////////////////////
/**
*
@ -297,9 +361,42 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);
if (zType != nd4j::DataType::BOOL)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
                                                   int opNum,
                                                   void *hX, Nd4jLong *hXShapeInfo,
                                                   void *dX, Nd4jLong *dXShapeInfo,
                                                   void *hY, Nd4jLong *hYShapeInfo,
                                                   void *dY, Nd4jLong *dYShapeInfo,
                                                   void *hZ, Nd4jLong *hZShapeInfo,
                                                   void *dZ, Nd4jLong *dZShapeInfo,
                                                   void *extraParams) {
    // CPU dispatch for element-wise integer transforms: validates that x, y and z
    // share one integer data type, then expands the PairWiseIntTransform kernel.
#ifdef _OPENMP
    omp_set_nested(1);
#endif

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    // all three operands must share a single integer data type
    if (xType != yType || xType != zType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);

    if (!nd4j::DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform requires integer data type", zType);    // fixed: typo "execSPairwiseInt" in the original message

    BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES);
}
////////////////////////////////////////////////////////////////////////
/**
*
@ -739,6 +836,64 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES);
}
////////////////////////////////////////////////////////////////////////
// CPU dispatch for integer scalar ops: z[i] = op(x[i], scalar). Validates that x,
// the scalar and z share one integer data type, then expands ScalarIntTransform.
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalar, Nd4jLong *hSscalarShapeInfo,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
// NOTE(review): the condition also compares zType, but the exception is built from
// (xType, yType) only — a z mismatch is reported without mentioning z's type; verify intended
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
if (!nd4j::DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType);
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES);
}
////////////////////////////////////////////////////////////////////////
// CPU dispatch for the TAD (along-dimension) variant of integer scalar ops: one
// scalar per TAD of x. Validates that x, the scalars array and z share one integer
// data type, then expands the dimension-wise ScalarIntTransform kernel.
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
int opNum,
void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo,
void *extraParams,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hScalars, Nd4jLong *hScalarShapeInfo,
void *dScalars, Nd4jLong *dScalarShapeInfo,
int *dimension, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
#ifdef _OPENMP
omp_set_nested(1);
#endif
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
// all three operands must share a single integer data type
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
if (!nd4j::DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType);
BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);
}
////////////////////////////////////////////////////////////////////////
/**
*

View File

@ -38,19 +38,22 @@
#include <loops/reduce_same.h>
#include <loops/reduce_bool.h>
#include <loops/reduce_long.h>
#include <loops/broadcasting.h>
#include <loops/indexreduce.h>
#include <loops/pairwise_transform.h>
#include <loops/pairwise_bool.h>
#include <loops/pairwise_int.h>
#include <loops/broadcasting_bool.h>
#include <loops/broadcasting_int.h>
#include <loops/broadcasting.h>
#include <loops/reduce_float.h>
#include <loops/reduce3.h>
#include <loops/summarystatsreduce.h>
#include <loops/transform_same.h>
#include <loops/scalar.h>
#include <loops/random.h>
#include <loops/special_kernels.h>
#include <loops/scalar.h>
#include <loops/scalar_bool.h>
#include <loops/scalar_int.h>
using namespace nd4j;
@ -152,6 +155,39 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc,
throw cuda_exception::build("execPairwiseBoolTransform failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc,
                                                    int opNum,
                                                    void *hX, Nd4jLong *hXShapeInfo,
                                                    void *dX, Nd4jLong *dXShapeInfo,
                                                    void *hY, Nd4jLong *hYShapeInfo,
                                                    void *dY, Nd4jLong *dYShapeInfo,
                                                    void *hZ, Nd4jLong *hZShapeInfo,
                                                    void *dZ, Nd4jLong *dZShapeInfo,
                                                    void *extraParams) {
    // CUDA dispatch for element-wise integer transforms: validates a single shared
    // integer data type, launches the kernel on the context's stream, and
    // (temporarily, pre-release) synchronizes to surface launch failures.
    auto stream = lc->getCudaStream();

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (!DataTypeUtils::isZ(zType))
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::INT32, zType);    // fixed: expected type previously reported as BOOL (copy-paste from the bool variant)

    if (yType != xType || zType != xType)
        throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType);

    dim3 launchDims(256, 1024, 16384);

    BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES)

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execPairwiseIntTransform failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc,
int opNum,
@ -252,6 +288,81 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
throw cuda_exception::build("execInverseBroadcastBool failed", res);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
                                           int opNum,
                                           void *hX, Nd4jLong *hXShapeInfo,
                                           void *dX, Nd4jLong *dXShapeInfo,
                                           void *hY, Nd4jLong *hYShapeInfo,
                                           void *dY, Nd4jLong *dYShapeInfo,
                                           void *hZ, Nd4jLong *hZShapeInfo,
                                           void *dZ, Nd4jLong *dZShapeInfo,
                                           int *dimension, int dimensionLength,
                                           Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                           Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
    // CUDA dispatch for integer broadcast ops: validates a single shared integer
    // data type, launches the broadcast kernel on the context's stream, and
    // (temporarily, pre-release) synchronizes to surface launch failures.
    auto stream = lc->getCudaStream();

    auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
    auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
    auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (!DataTypeUtils::isZ(zType))
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");

    if (yType != xType || zType != xType)
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");

    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
        printf("F3B opNum:[%i]\n", opNum);

    dim3 launchDims(256, 256, 1024);

    BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)

    // TODO: remove after the release
    auto res = cudaStreamSynchronize(*stream);
    if (res != 0)
        throw cuda_exception::build("execBroadcastInt failed", res);    // fixed: previously reported "execBroadcastBool failed" (copy-paste)
}
void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
                                                  int opNum,
                                                  void *hX, Nd4jLong *hXShapeInfo,
                                                  void *dX, Nd4jLong *dXShapeInfo,
                                                  void *hY, Nd4jLong *hYShapeInfo,
                                                  void *dY, Nd4jLong *dYShapeInfo,
                                                  void *hZ, Nd4jLong *hZShapeInfo,
                                                  void *dZ, Nd4jLong *dZShapeInfo,
                                                  int *dimension, int dimensionLength,
                                                  Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
                                                  Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
    // CUDA dispatch for the inverse integer broadcast: same type validation as the
    // forward variant, then launches the execInverseBroadcast kernel and
    // (temporarily, pre-release) synchronizes to surface launch failures.
    auto stream = lc->getCudaStream();

    const auto typeX = nd4j::ArrayOptions::dataType(hXShapeInfo);
    const auto typeY = nd4j::ArrayOptions::dataType(hYShapeInfo);
    const auto typeZ = nd4j::ArrayOptions::dataType(hZShapeInfo);

    // z must be integral, and all three operands must share one data type
    if (!DataTypeUtils::isZ(typeZ))
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
    if (typeY != typeX || typeZ != typeX)
        throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");

    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
        printf("F3BI opNum:[%i]\n", opNum);

    dim3 launchDims(256, 256, 1024);

    BUILD_SINGLE_SELECTOR(typeX, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)

    // TODO: remove after the release
    auto syncStatus = cudaStreamSynchronize(*stream);
    if (syncStatus != 0)
        throw cuda_exception::build("execInverseBroadcastInt failed", syncStatus);
}
////////////////////////////////////////////////////////////////////////
/**
*
@ -1114,6 +1225,75 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
throw cuda_exception::build("execScalarBool B failed", res);
}
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Applies an integer scalar op element-wise: Z = op(X, scalar).
// X, scalar and Z must all carry the same integral dtype.
// Note: allowParallelism is accepted for interface parity but not used on CUDA.
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalar, Nd4jLong *hScalarShapeInfo,
                                        void *dScalar, Nd4jLong *dScalarShapeInfo,
                                        void *extraParams, bool allowParallelism) {

    auto cudaStream = lc->getCudaStream();

    dim3 launchDims = dim3(256, 512, 8192);

    const auto typeX = nd4j::ArrayOptions::dataType(hXShapeInfo);
    const auto typeY = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    const auto typeZ = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (typeX != typeY || typeZ != typeX)
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");

    if (!DataTypeUtils::isZ(typeZ))
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type");

    BUILD_SINGLE_SELECTOR(typeX, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, cudaStream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES);

    // TODO: remove after the release
    auto syncStatus = cudaStreamSynchronize(*cudaStream);
    if (syncStatus != 0)
        throw cuda_exception::build("execScalarInt failed", syncStatus);
}
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Applies an integer scalar op along TADs: each TAD of X gets its own scalar
// from `dScalars`. All operands must share one integral dtype.
void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
                                        int opNum,
                                        void *hX, Nd4jLong *hXShapeInfo,
                                        void *dX, Nd4jLong *dXShapeInfo,
                                        void *extraParams,
                                        void *hZ, Nd4jLong *hZShapeInfo,
                                        void *dZ, Nd4jLong *dZShapeInfo,
                                        void *hScalars, Nd4jLong *hScalarShapeInfo,
                                        void *dScalars, Nd4jLong *dScalarShapeInfo,
                                        int *dimension, int dimensionLength,
                                        Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                        Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    auto cudaStream = lc->getCudaStream();

    dim3 launchDims(256, 512, 8192);

    const auto typeX = nd4j::ArrayOptions::dataType(hXShapeInfo);
    const auto typeY = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
    const auto typeZ = nd4j::ArrayOptions::dataType(hZShapeInfo);

    if (typeX != typeY || typeZ != typeX)
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");

    if (!DataTypeUtils::isZ(typeZ))
        throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type");

    BUILD_SINGLE_SELECTOR(typeX, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, cudaStream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES);

    // TODO: remove after the release
    auto syncStatus = cudaStreamSynchronize(*cudaStream);
    if (syncStatus != 0)
        throw cuda_exception::build("execScalarInt B failed", syncStatus);
}
////////////////////////////////////////////////////////////////////////
void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
int opNum,

View File

@ -71,12 +71,25 @@ inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) {
case broadcast::And: return pairwise::And;
case broadcast::Or: return pairwise::Or;
case broadcast::Xor: return pairwise::Xor;
case broadcast::Not: return pairwise::Not;
case broadcast::Not: return pairwise::Not;
default:
throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation");
}
}
// Maps each broadcast-level integer op onto its pairwise counterpart.
// Throws for any op that has no pairwise equivalent.
inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) {
    if (op == broadcast::IntOps::IntAnd)           return pairwise::IntOps::IntAnd;
    if (op == broadcast::IntOps::IntOr)            return pairwise::IntOps::IntOr;
    if (op == broadcast::IntOps::IntXor)           return pairwise::IntOps::IntXor;
    if (op == broadcast::IntOps::ShiftLeft)        return pairwise::IntOps::ShiftLeft;
    if (op == broadcast::IntOps::ShiftRight)       return pairwise::IntOps::ShiftRight;
    if (op == broadcast::IntOps::CyclicShiftLeft)  return pairwise::IntOps::CyclicShiftLeft;
    if (op == broadcast::IntOps::CyclicShiftRight) return pairwise::IntOps::CyclicShiftRight;

    throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation");
}
}
#endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H

View File

@ -0,0 +1,164 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
/*
* broadcasting.h
*
* Created on: Dec 28, 2015
* Author: agibsonccc
*/
#ifndef BROADCASTING_INT_H_
#define BROADCASTING_INT_H_
#include <dll.h>
#include <helpers/shape.h>
#include <templatemath.h>
#include <pairwise_util.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include <helpers/DebugHelper.h>
#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifdef __JNI__
#include <jni.h>
#endif
#include <helpers/TAD.h>
#include "legacy_ops.h"
namespace functions {
namespace broadcast {
/**
* Broadcast operation
* for broadcasting a smaller tensor
* along long a bigger one.
*/
// Declaration-only class: all broadcast integer ops (and/or/xor/shifts) for a
// single integral element type X. X, Y and Z all share that type — there is no
// separate Y/Z type parameter, unlike the generic Broadcast class.
template<typename X>
class BroadcastInt {
public:

#ifdef __CUDACC__

    // Device-side kernel body: applies OpType between each TAD of x and y,
    // writing into the corresponding TAD of result.
    template<typename OpType>
    static __device__ void transformCuda(
            void *x,
            Nd4jLong *xShapeInfo,
            void *y,
            Nd4jLong *yShapeInfo,
            void *result,
            Nd4jLong *resultShapeInfo,
            int *dimension,
            int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    // Host-side launcher for one concrete op class.
    template <typename OpClass>
    static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    // Host-side entry point: dispatches opNum to the matching intermediateBroadcast.
    static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    // Device-side kernel body for the inverse direction (y holds the TADs).
    template<typename OpType>
    static __device__ void transformInverseCuda(
            void *x,
            Nd4jLong *xShapeInfo,
            void *y,
            Nd4jLong *yShapeInfo,
            void *result,
            Nd4jLong *resultShapeInfo,
            int *dimension,
            int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    // Host-side launcher for one concrete op class, inverse direction.
    template <typename OpClass>
    static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

    // Host-side entry point for the inverse direction.
    static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

#endif

    // CPU runtime dispatch by op number (forwards to the templated exec below).
    static void exec(int opNum,
                     void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);

    // CPU runtime dispatch for the inverse direction (y holds the TADs).
    static void execInverse(int opNum,
                            void *x,
                            Nd4jLong *xShapeInfo,
                            void *y,
                            Nd4jLong *yShapeInfo,
                            void *result,
                            Nd4jLong *resultShapeInfo,
                            int *dimension,
                            int dimensionLength,
                            Nd4jLong *tadShapeInfo,
                            Nd4jLong *tadOffset,
                            Nd4jLong *tadShapeInfoZ,
                            Nd4jLong *tadOffsetZ);

    /**
     * CPU execution
     * @param x the input
     * @param xShapeInfo the x shape information
     * @param y the y data
     * @param yShapeInfo the y shape information
     * @param result the result
     * @param resultShapeInfo the result shape information
     * @param dimension the dimension to broadcast along long
     * @param dimensionLength the length of the dimension buffer
     */
    template<typename OpType>
    static void exec(void *x,
                     Nd4jLong *xShapeInfo,
                     void *y,
                     Nd4jLong *yShapeInfo,
                     void *result,
                     Nd4jLong *resultShapeInfo,
                     int *dimension,
                     int dimensionLength,
                     Nd4jLong *tadShapeInfo,
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ);

    // Templated CPU implementation for the inverse direction.
    template<typename OpType>
    static void execInverse(void *x,
                            Nd4jLong *xShapeInfo,
                            void *y,
                            Nd4jLong *yShapeInfo,
                            void *result,
                            Nd4jLong *resultShapeInfo,
                            int *dimension,
                            int dimensionLength,
                            Nd4jLong *tadShapeInfo,
                            Nd4jLong *tadOffset,
                            Nd4jLong *tadShapeInfoZ,
                            Nd4jLong *tadOffsetZ);
};
}
}
#endif /* BROADCASTING_INT_H_ */

View File

@ -0,0 +1,464 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#include <loops/broadcasting_int.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <LoopKind.h>
#include <helpers/ConstantTadHelper.h>
using namespace simdOps;
namespace functions {
namespace broadcast {
// Runtime entry point: selects the concrete op for `opNum` from
// BROADCAST_INT_OPS and forwards every argument to the templated
// exec<OpType> implementation below.
template <typename X>
void BroadcastInt<X>::exec(const int opNum,
                           void *x,
                           Nd4jLong *xShapeInfo,
                           void *y,
                           Nd4jLong *yShapeInfo,
                           void *z,
                           Nd4jLong *zShapeInfo,
                           int *dimension,
                           int dimensionLength,
                           Nd4jLong *xTadShapeInfo,
                           Nd4jLong *xTadOffset,
                           Nd4jLong *zTadShapeInfo,
                           Nd4jLong *zTadOffset) {
    DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
                                     xShapeInfo,
                                     y,
                                     yShapeInfo,
                                     z,
                                     zShapeInfo,
                                     dimension,
                                     dimensionLength,
                                     xTadShapeInfo,
                                     xTadOffset,
                                     zTadShapeInfo,
                                     zTadOffset), BROADCAST_INT_OPS);
}
// Runtime entry point for the inverse direction (y carries the TADs):
// dispatches `opNum` over BROADCAST_INT_OPS to execInverse<OpType>.
template <typename X>
void BroadcastInt<X>::execInverse(const int opNum,
                                  void *x,
                                  Nd4jLong *xShapeInfo,
                                  void *y,
                                  Nd4jLong *yShapeInfo,
                                  void *z,
                                  Nd4jLong *zShapeInfo,
                                  int *dimension,
                                  int dimensionLength,
                                  Nd4jLong *xTadShapeInfo,
                                  Nd4jLong *xTadOffset,
                                  Nd4jLong *zTadShapeInfo,
                                  Nd4jLong *zTadOffset) {
    DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x,
                                            xShapeInfo,
                                            y,
                                            yShapeInfo,
                                            z,
                                            zShapeInfo,
                                            dimension,
                                            dimensionLength,
                                            xTadShapeInfo,
                                            xTadOffset,
                                            zTadShapeInfo,
                                            zTadOffset), BROADCAST_INT_OPS);
}
// CPU broadcast: for each TAD of x along `dimension`, computes
// z_tad[j] = OpType::op(x_tad[j], y[j]). All buffers share element type X.
// The branch ladder below picks the cheapest indexing scheme the shapes allow:
// unit-stride, non-unit ews, then progressively more general offset math.
template <typename X>
template<typename OpType>
void BroadcastInt<X>::exec(void *vx,
                           Nd4jLong *xShapeInfo,
                           void *vy,
                           Nd4jLong *yShapeInfo,
                           void *vz,
                           Nd4jLong *zShapeInfo,
                           int *dimension,
                           int dimensionLength,
                           Nd4jLong *xTadShapeInfo,
                           Nd4jLong *xTadOffset,
                           Nd4jLong *zTadShapeInfo,
                           Nd4jLong *zTadOffset) {

    auto x = reinterpret_cast<X *>(vx);
    auto y = reinterpret_cast<X *>(vy);
    auto z = reinterpret_cast<X *>(vz);

    //decompose in to several sub tads after
    //moving all dimensions (in sorted order)
    //to the back.
    //permuted version of the x shape info for setting up the tad problem
    auto xTadShapeShapeInfo = xTadShapeInfo;
    auto tadOffsets = xTadOffset;

    // caller may pass precomputed TADs; otherwise build/cache them here
    if (xTadShapeInfo == nullptr || tadOffsets == nullptr) {
        auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength);

        xTadShapeShapeInfo = tadPack.primaryShapeInfo();
        tadOffsets = tadPack.primaryOffsets();
    }

    //int *resultStride = shape::stride(xTadShapeShapeInfo);
    unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo);
    unsigned int tads = shape::length(xShapeInfo) / tadLength;

    // no explicit output TADs => z shares x's TAD layout
    if (zTadShapeInfo == nullptr) {
        zTadShapeInfo = xTadShapeShapeInfo;
        zTadOffset = tadOffsets;
    }

    auto lenZ = shape::length(zTadShapeInfo);
    auto lenY = shape::length(yShapeInfo);

    // scale thread count with the number of TADs, capped by OMP's limit
    int tadsPerThread = tads / TAD_THRESHOLD;
    int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
    threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());

    auto xEws = shape::elementWiseStride(xTadShapeShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zTadShapeInfo);

    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

    if (kindOfLoop == nd4j::LoopKind::EWS1) {
        // fastest path: all three buffers contiguous with unit stride
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oX = x + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], y[f]);
        }
    }
    else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
        // strided but still element-wise-compatible buffers
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oX = x + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
        }
    }
    else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
        // x-TAD, y and z-TAD all share layout: one offset serves all three
        uint tadShapeShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            // TODO: cover this codebranch with tests
            // all this stuff already happens within thread
            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                oZ[offset] = OpType::op(oX[offset], y[offset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) {
        // x-TAD and y share layout; z needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(oX[offset], y[offset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
        // x-TAD and z-TAD share layout; y needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                oZ[offset] = OpType::op(oX[offset], y[yOffset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) {
        // y and z-TAD share layout; x needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                oZ[offset] = OpType::op(oX[xOffset], y[offset]);
            }
        }
    }
    else {
        // fully general case: independent offsets for x, y and z
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        uint yShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oX = x + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX);
                auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
            }
        }
    }
}
// Inverse CPU broadcast: the TADs are taken from y instead of x, i.e.
// z_tad[j] = OpType::op(x[j], y_tad[j]). Mirrors exec<OpType> above with the
// roles of x and y swapped; the same ladder of indexing fast paths applies.
template <typename X>
template<typename OpType>
void BroadcastInt<X>::execInverse(void *vx,
                                  Nd4jLong *xShapeInfo,
                                  void *vy,
                                  Nd4jLong *yShapeInfo,
                                  void *vz,
                                  Nd4jLong *zShapeInfo,
                                  int *dimension,
                                  int dimensionLength,
                                  Nd4jLong *yTadShapeInfo,
                                  Nd4jLong *yTadOffset,
                                  Nd4jLong *zTadShapeInfo,
                                  Nd4jLong *zTadOffset) {

    auto x = reinterpret_cast<X *>(vx);
    auto y = reinterpret_cast<X *>(vy);
    auto z = reinterpret_cast<X *>(vz);

    //decompose in to several sub tads after
    //moving all dimensions (in sorted order)
    //to the back.
    //permuted version of the x shape info for setting up the tad problem
    auto yTadShapeShapeInfo = yTadShapeInfo;
    auto tadOffsets = yTadOffset;

    // build/cache y's TADs when the caller did not supply them
    if (yTadShapeInfo == nullptr || tadOffsets == nullptr) {
        auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength);

        yTadShapeShapeInfo = tadPack.primaryShapeInfo();
        tadOffsets = tadPack.primaryOffsets();
    }

    //int *resultStride = shape::stride(yTadShapeShapeInfo);
    unsigned int tadLength = shape::length(yTadShapeShapeInfo);
    unsigned int tads = shape::length(yShapeInfo) / tadLength;

    // no explicit output TADs => z shares y's TAD layout
    if (zTadShapeInfo == nullptr) {
        zTadShapeInfo = yTadShapeShapeInfo;
        zTadOffset = tadOffsets;
    }

    auto lenZ = shape::length(zTadShapeInfo);
    auto lenX = shape::length(xShapeInfo);

    // scale thread count with the number of TADs, capped by OMP's limit
    int tadsPerThread = tads / TAD_THRESHOLD;
    int threads = nd4j::math::nd4j_max<int>(1, tadsPerThread);
    threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());

    auto yEws = shape::elementWiseStride(yTadShapeShapeInfo);
    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto zEws = shape::elementWiseStride(zTadShapeInfo);

    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo);

    if (kindOfLoop == nd4j::LoopKind::EWS1) {
        // fastest path: unit strides everywhere
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oY = y + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(x[f], oY[f]);
        }
    }
    else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
        // strided but still element-wise-compatible buffers
        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oY = y + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            PRAGMA_OMP_SIMD
            for (uint f = 0; f < tadLength; f++)
                oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
        }
    }
    else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
        // y-TAD, x and z-TAD all share layout: one offset serves all three
        uint tadShapeShapeInfoCast[MAX_RANK];
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oY = y + tadOffsets[i];
            auto oZ = z + zTadOffset[i];

            // TODO: cover this codebranch with tests
            // all this stuff already happens within thread
            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                oZ[offset] = OpType::op(x[offset], oY[offset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) {
        // y-TAD and x share layout; z needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oY = y + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(x[offset], oY[offset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
        // y-TAD and z-TAD share layout; x needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint xShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oY = y + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                oZ[offset] = OpType::op(x[xOffset], oY[offset]);
            }
        }
    }
    else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) {
        // x and z-TAD share layout; y-TAD needs its own offset
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint xShapeInfoCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oY = y + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                oZ[offset] = OpType::op(x[offset], oY[yOffset]);
            }
        }
    }
    else {
        // fully general case: independent offsets for x, y and z
        uint xShapeInfoCast[MAX_RANK];
        uint tadShapeShapeInfoCast[MAX_RANK];
        uint tadShapeInfoZCast[MAX_RANK];
        bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
        bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast);
        bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast);

        PRAGMA_OMP_PARALLEL_FOR_THREADS(threads)
        for (int i = 0; i < tads; i++) {
            auto oZ = z + zTadOffset[i];
            auto oY = y + tadOffsets[i];

            PRAGMA_OMP_SIMD
            for (int f = 0; f < tadLength; f++) {
                auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX);
                auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY);
                auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ);
                oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
            }
        }
    }
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}

View File

@ -0,0 +1,309 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <loops/pairwise_int.h>
#include <types/types.h>
#include <LoopKind.h>
#include <OmpLaunchHelper.h>
using namespace simdOps;
namespace functions {
namespace pairwise_transforms {
// Runtime entry point for the flat (element-wise-stride) variant: dispatches
// `opNum` over PAIRWISE_INT_OPS to the templated exec<OpType> below.
template <typename X>
void PairWiseIntTransform<X>::exec(
        const int opNum,
        void *x,
        Nd4jLong xEws,
        void *y,
        Nd4jLong yEws,
        void *z,
        Nd4jLong zEws,
        void *extraParams,
        Nd4jLong n) {
    DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
                                     xEws,
                                     y,
                                     yEws,
                                     z,
                                     zEws,
                                     extraParams,
                                     n), PAIRWISE_INT_OPS);
};
// Flat pairwise loop: z[i] = OpType::op(x[i], y[i]) over n elements, with all
// buffers addressed by element-wise stride only. Work is split across OMP
// threads via OmpLaunchHelper; a dedicated branch handles the all-unit-stride
// case so the SIMD loop can use plain indexing.
template <typename X>
template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx,
                                   Nd4jLong xEws,
                                   void *vy,
                                   Nd4jLong yEws,
                                   void *vz,
                                   Nd4jLong zEws,
                                   void *vextraParams,
                                   const Nd4jLong n) {

    // all operands share integral element type X; extraParams is typed the same
    auto x = reinterpret_cast<X *>(vx);
    auto y = reinterpret_cast<X *>(vy);
    auto z = reinterpret_cast<X *>(vz);
    auto extraParams = reinterpret_cast<X *>(vextraParams);

    nd4j::OmpLaunchHelper info(n);

    if (xEws == 1 && yEws == 1 && zEws == 1) {
        // unit-stride fast path
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            Nd4jLong threadOffset = info.getThreadOffset(threadNum);
            auto xi = x + threadOffset;
            auto yi = y + threadOffset;
            auto zi = z + threadOffset;

            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (Nd4jLong i = 0; i < ulen; i++)
                zi[i] = OpType::op(xi[i], yi[i], extraParams);
        }
    }
    else {
        // general strided path
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            Nd4jLong threadOffset = info.getThreadOffset(threadNum);
            auto xi = x + xEws*threadOffset;
            auto yi = y + yEws*threadOffset;
            auto zi = z + zEws*threadOffset;

            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (Nd4jLong i = 0; i < ulen; i++)
                zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams);
        }
    }
}
// Runtime entry point for the shapeInfo-driven variant: dispatches `opNum`
// over PAIRWISE_INT_OPS to the templated shapeInfo exec<OpType> below.
template <typename X>
void PairWiseIntTransform<X>::exec(
        const int opNum,
        void *x,
        Nd4jLong *xShapeInfo,
        void *y,
        Nd4jLong *yShapeInfo,
        void *z,
        Nd4jLong *zShapeInfo,
        void *extraParams) {
    DISPATCH_BY_OPNUM_T(exec, PARAMS(x,
                                     xShapeInfo,
                                     y,
                                     yShapeInfo,
                                     z,
                                     zShapeInfo,
                                     extraParams),
                        PAIRWISE_INT_OPS);
};
// ShapeInfo-driven pairwise loop: z[i] = OpType::op(x[i], y[i]) with full
// shape/stride handling. A scalar y gets its own branch (y[0] applied to every
// x element); otherwise the code falls back to the flat ews loop when shapes
// allow it, and to progressively more general offset computation when not.
template <typename X>
template <typename OpType>
void PairWiseIntTransform<X>::exec(void *vx, Nd4jLong* xShapeInfo,
                                   void *vy, Nd4jLong* yShapeInfo,
                                   void *vz, Nd4jLong* zShapeInfo,
                                   void *vextraParams) {

    auto x = reinterpret_cast<X *>(vx);
    auto y = reinterpret_cast<X *>(vy);
    auto z = reinterpret_cast<X *>(vz);
    auto extraParams = reinterpret_cast<X *>(vextraParams);

    // n is x's length; x and z are assumed equally long here
    auto n = shape::length(xShapeInfo);
    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    nd4j::OmpLaunchHelper info(n);

    if (shape::isScalar(yShapeInfo)) {
        // scalar y: broadcast y[0] over all of x
        uint xShapeInfoCast[MAX_RANK];
        const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

        if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
            // x and z share layout: one offset for both
            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for(Nd4jLong i = 0; i < ulen; i++) {
                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    z[offset] = OpType::op(x[offset], y[0], extraParams);
                }
            }
        }
        else {
            // x and z differ: compute z's offset separately
            uint zShapeInfoCast[MAX_RANK];
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for(Nd4jLong i = 0; i < ulen; i++) {
                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                    z[zOffset] = OpType::op(x[xOffset], y[0], extraParams);
                }
            }
        }
        return;
    }

    const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo);
    const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo);

    if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) {
        exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, n);
    }
    else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape
        // NOTE(review): only y's length worth of elements are processed here —
        // mirrors the other legacy pairwise transforms; verify for len(x) > len(y)
        exec<OpType>(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo));
    }
    else {

        if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
            // x, y and z all share layout: one offset serves all three
            uint xShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++) {
                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    z[offset] = OpType::op(x[offset], y[offset], extraParams);
                }
            }
        }
        else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) {
            // x and y share layout; z needs its own offset
            uint xShapeInfoCast[MAX_RANK];
            uint zShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++) {
                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                    z[zOffset] = OpType::op(x[offset], y[offset], extraParams);
                }
            }
        }
        else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
            // x and z share layout; y needs its own offset
            uint xShapeInfoCast[MAX_RANK];
            uint yShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++) {
                    auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                    z[offset] = OpType::op(x[offset], y[yOffset], extraParams);
                }
            }
        }
        else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) {
            // y and z share layout; x needs its own offset
            uint xShapeInfoCast[MAX_RANK];
            uint yShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++) {
                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                    z[offset] = OpType::op(x[xOffset], y[offset], extraParams);
                }
            }
        }
        else {
            // fully general case: independent offsets for x, y and z
            uint xShapeInfoCast[MAX_RANK];
            uint yShapeInfoCast[MAX_RANK];
            uint zShapeInfoCast[MAX_RANK];
            const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast);
            const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast);
            const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast);

            PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
            {
                auto threadNum = omp_get_thread_num();
                auto threadOffset = info.getThreadOffset(threadNum);
                auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

                PRAGMA_OMP_SIMD
                for (Nd4jLong i = 0; i < ulen; i++) {
                    auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX);
                    auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY);
                    auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ);
                    z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
                }
            }
        }
    }
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}

View File

@ -0,0 +1,255 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../scalar_int.h"
#include <op_boilerplate.h>
#include <types/types.h>
#include <LoopKind.h>
#include "../legacy_ops.h"
using namespace simdOps;
namespace functions {
namespace scalar {
//////////////////////////////////////////////////////////////////////////
// CPU scalar-along-dimension transform: for every TAD (sub-array) r of x,
// computes z_tad[r][f] = OpType::op(x_tad[r][f], scalars[r]) — i.e. one
// scalar value per TAD. Work is parallelized over TADs with OpenMP.
//
// vx / xShapeInfo          : input buffer and its shape info
// vextraParams             : op-specific extra params (may be nullptr)
// vz / zShapeInfo          : output buffer and its shape info
// vscalars                 : buffer holding one scalar per TAD
// dimension, dimensionLength: dimensions the TADs were built along
// x/zTadShapeInfo, Offsets : precomputed TAD descriptors; the z side falls
//                            back to the x side when not supplied
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx, Nd4jLong *xShapeInfo,
                                      void *vextraParams,
                                      void *vz, Nd4jLong *zShapeInfo,
                                      void *vscalars,
                                      int *dimension, int dimensionLength,
                                      Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                      Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) {

    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalars = reinterpret_cast<X *>(vscalars);
    auto extraParams = reinterpret_cast<X *>(vextraParams);

    // reuse the x-side TAD description for z when the caller didn't pass one
    if (zTadShapeInfo == nullptr) {
        zTadShapeInfo = xTadShapeInfo;
        zTadOffsets = xTadOffsets;
    }

    // tad preparation
    const int xTadEws = shape::elementWiseStride(xTadShapeInfo);
    const int zTadEws = shape::elementWiseStride(zTadShapeInfo);
    const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength);
    const int numTads = shape::length(xShapeInfo) / tadLength;

    nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo);

    // only element-wise-stride loop shapes are implemented here; anything
    // else indicates the caller prepared TADs this routine can't handle
    if (kindOfLoop != nd4j::LoopKind::EWS1 && kindOfLoop != nd4j::LoopKind::EWSNONZERO) {
        printf("ScalarIntTransform<X>::transform: super-bad loop visited. Shouldn't ever happen\n");
        return;
    }

    // never spawn more threads than there are TADs
    int num_threads = nd4j::math::nd4j_min<int>(numTads, omp_get_max_threads());

    if (kindOfLoop == nd4j::LoopKind::EWS1) {
        // unit stride on both sides: contiguous, SIMD-friendly inner loop
        PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
        for (unsigned int r = 0; r < numTads; r++) {
            auto oZ = z + zTadOffsets[r];
            auto oX = x + xTadOffsets[r];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f] = OpType::op(oX[f], scalars[r], extraParams);
        }
    }
    else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO
        // non-unit (but nonzero) element-wise strides on either side
        PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads)
        for (unsigned int r = 0; r < numTads; r++) {
            auto oZ = z + zTadOffsets[r];
            auto oX = x + xTadOffsets[r];

            PRAGMA_OMP_SIMD
            for (unsigned int f = 0; f < tadLength; f++)
                oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams);
        }
    }
}
// Op-number dispatcher for the scalar-along-dimension (TAD) variant:
// DISPATCH_BY_OPNUM_T expands to a switch over SCALAR_INT_OPS and forwards
// to the matching transform<OpType> specialization above.
template<typename X>
void ScalarIntTransform<X>::transform(int opNum,
                                      void *x,
                                      Nd4jLong *xShapeInfo,
                                      void *extraParams,
                                      void *z,
                                      Nd4jLong *zShapeInfo,
                                      void *scalars,
                                      int *dimension,
                                      int dimensionLength,
                                      Nd4jLong *xTadShapeInfo,
                                      Nd4jLong *xTadOffsets,
                                      Nd4jLong *zTadShapeInfo,
                                      Nd4jLong *zTadOffsets) {
    DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS);
}

// Op-number dispatcher for the flat element-wise-stride (EWS) variant.
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum,
                                      void *x,
                                      Nd4jLong xEws,
                                      void *z,
                                      Nd4jLong zEws,
                                      void *scalar,
                                      void *extraParams,
                                      const Nd4jLong n) {
    DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS);
}

// Op-number dispatcher for the shape-info-driven variant.
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum,
                                      void *x,
                                      Nd4jLong *xShapeInfo,
                                      void *z,
                                      Nd4jLong *zShapeInfo,
                                      void *scalar,
                                      void *extraParams) {
    DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS);
}
// Shape-info-driven scalar transform: z[i] = OpType::op(x[i], scalar).
// Fast-paths to the EWS overload when both buffers have usable element-wise
// strides; otherwise walks both arrays via index->offset translation.
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx,
                                      Nd4jLong *xShapeInfo,
                                      void *vz,
                                      Nd4jLong *zShapeInfo,
                                      void *vscalar,
                                      void *vextraParams) {

    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];   // single scalar value
    auto extraParams = reinterpret_cast<X *>(vextraParams);

    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);
    auto len = shape::length(xShapeInfo);

    // nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws);

    nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo);

    // fast path: strided traversal handled by the EWS overload
    if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
        transform<OpType>(x, xEws, z, zEws, vscalar, extraParams, len);
        return;
    }

    uint xShapeInfoCast[MAX_RANK];
    const bool canCastX = nd4j::DataTypeUtils::castShapeInfo<uint>(xShapeInfo, xShapeInfoCast);

    nd4j::OmpLaunchHelper info(len);

    if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) {
        // x and z share layout: one offset computation serves both arrays
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
                auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
                z[offset] = OpType::op(x[offset], scalar, extraParams);
            }
        }
    }
    else {
        // distinct layouts: translate the linear index separately for x and z
        uint zShapeInfoCast[MAX_RANK];
        const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo<uint>(zShapeInfo, zShapeInfoCast);

        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++) {
                auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX);
                auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ);
                z[zOffset] = OpType::op(x[xOffset], scalar, extraParams);
            }
        }
    }
}
// Element-wise-stride scalar transform: z[i*zEws] = OpType::op(x[i*xEws], scalar)
// for i in [0, len). Parallelized with OpenMP; a dedicated branch avoids the
// stride multiplication when both strides are 1.
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *vx,
                                      Nd4jLong xEws,
                                      void *vz,
                                      Nd4jLong zEws,
                                      void *vscalar,
                                      void *vextraParams,
                                      const Nd4jLong len) {

    auto x = reinterpret_cast<X *>(vx);
    auto z = reinterpret_cast<X *>(vz);
    auto scalar = reinterpret_cast<X *>(vscalar)[0];   // single scalar value
    auto extraParams = reinterpret_cast<X *>(vextraParams);

    nd4j::OmpLaunchHelper info(len);

    if (xEws == 1 && zEws == 1) {
        // contiguous case: each thread works on its own [threadOffset, +ulen) slice
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto xi = x + threadOffset;
            auto zi = z + threadOffset;
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++)
                zi[i] = OpType::op(xi[i], scalar, extraParams);
        }
    }
    else {
        // strided case: base pointers are advanced by stride * threadOffset
        PRAGMA_OMP_PARALLEL_THREADS(info._numThreads)
        {
            auto threadNum = omp_get_thread_num();
            auto threadOffset = info.getThreadOffset(threadNum);
            auto xi = x + xEws * threadOffset;
            auto zi = z + zEws * threadOffset;
            auto ulen = static_cast<unsigned int>(info.getItersPerThread(threadNum));

            PRAGMA_OMP_SIMD
            for (unsigned int i = 0; i < ulen; i++)
                zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams);
        }
    }
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);
}
}

View File

@ -0,0 +1,291 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#include <loops/broadcasting_int.h>
#include <loops/legacy_ops.h>
#include <types/types.h>
#include <Environment.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include <stdexcept>
#include <StringUtils.h>
using namespace simdOps;
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// CUDA kernel entry point for the direct integer broadcast: forwards all
// arguments to BroadcastInt<X>::transformCuda<OpClass>, which walks the
// TADs of x and applies OpClass against y.
template<typename X, typename OpClass>
static __global__ void broadcastIntSimple(
        void *x,
        Nd4jLong *xShapeInfo,
        void *y,
        Nd4jLong *yShapeInfo,
        void *z,
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    functions::broadcast::BroadcastInt<X>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// CUDA kernel entry point for the inverse integer broadcast: forwards to
// BroadcastInt<X>::transformInverseCuda<OpClass>, which walks the TADs of y.
// NOTE(review): the name says "Bool" but this is the Int variant — looks like
// a copy-paste from broadcasting_bool; renaming would also require updating
// the launch site in intermediateInverseBroadcast.
template<typename X, typename OpClass>
static __global__ void broadcastBoolInverseSimple(
        void *x,
        Nd4jLong *xShapeInfo,
        void *y,
        Nd4jLong *yShapeInfo,
        void *z,
        Nd4jLong *zShapeInfo,
        int *dimension,
        int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    functions::broadcast::BroadcastInt<X>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
namespace functions {
namespace broadcast {
//////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpClass>
// Host-side launcher: starts the broadcastIntSimple kernel for a concrete
// op class on the given stream with the given launch configuration.
__host__ void BroadcastInt<X>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    broadcastIntSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}

//////////////////////////////////////////////////////////////////////////
// Op-number dispatcher: picks the BROADCAST_INT_OPS entry matching opNum
// and forwards to intermediateBroadcast<OpClass>.
template<typename X>
__host__ void BroadcastInt<X>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}

//////////////////////////////////////////////////////////////////////////
// Host-side launcher for the inverse broadcast kernel (TADs taken from y).
template<typename X>
template <typename OpClass>
__host__ void BroadcastInt<X>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    broadcastBoolInverseSimple<X, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
}

//////////////////////////////////////////////////////////////////////////
// Op-number dispatcher for the inverse broadcast.
template<typename X>
__host__ void BroadcastInt<X>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS))
}
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// Device-side inverse broadcast: here the TADs come from y
// (numTads = length(y) / tadLength), and the whole of x (one TAD's worth
// of elements) is applied against each y-TAD:
//   z_tad[r][i] = OpType::op(x[i], y_tad[r][i])
// Blocks iterate over TADs; threads iterate within a TAD.
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformInverseCuda(
        void *vx, Nd4jLong *xShapeInfo,
        void *vy, Nd4jLong *yShapeInfo,
        void *vz, Nd4jLong *zShapeInfo,
        int *dimension, int dimensionLength,
        Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    // fall back to the input-side TAD descriptors when z's weren't supplied
    if (tadOnlyShapeInfoZ == nullptr) {
        tadOnlyShapeInfoZ = tadOnlyShapeInfo;
        tadOffsetsZ = tadOffsets;
    }

    auto x = reinterpret_cast<X*>(vx);
    auto y = reinterpret_cast<X*>(vy);
    auto z = reinterpret_cast<X*>(vz);

    //decompose in to several sub tads after
    //moving all dimensions (in sorted order)
    //to the back.
    //permuted version of the x shape info for setting up the tad problem
    __shared__ Nd4jLong tadLength;
    __shared__ Nd4jLong tadEWS;
    __shared__ int numTads;
    __shared__ Nd4jLong xEWS;
    __shared__ Nd4jLong zEWS;

    // thread 0 computes the shared loop parameters once per block
    if (threadIdx.x == 0) {
        tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
        tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
        numTads = shape::length(yShapeInfo) / tadLength;
        xEWS = shape::elementWiseStride(xShapeInfo);
        zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
    }
    __syncthreads();

    for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
        auto rZ = z + tadOffsetsZ[r];
        auto rY = y + tadOffsets[r];

        if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
            // strided fast path
            for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
                rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
        }
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
                auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength);
                auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
                auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
                rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
            }
        }
    }
}
//////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////
// Device-side direct broadcast: TADs come from x
// (numTads = length(x) / tadLength), and y is applied against each x-TAD:
//   z_tad[r][i] = OpType::op(x_tad[r][i], y[i])
// Blocks iterate over TADs; threads iterate within a TAD.
template<typename X>
template <typename OpType>
__device__ void BroadcastInt<X>::transformCuda(
        void *vx, Nd4jLong *xShapeInfo,
        void *vy, Nd4jLong *yShapeInfo,
        void *vz, Nd4jLong *zShapeInfo,
        int *dimension, int dimensionLength,
        Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {

    // fall back to the input-side TAD descriptors when z's weren't supplied
    if (tadOnlyShapeInfoZ == nullptr) {
        tadOnlyShapeInfoZ = tadOnlyShapeInfo;
        tadOffsetsZ = tadOffsets;
    }

    auto x = reinterpret_cast<X*>(vx);
    auto y = reinterpret_cast<X*>(vy);
    auto z = reinterpret_cast<X*>(vz);

    //decompose in to several sub tads after
    //moving all dimensions (in sorted order)
    //to the back.
    //permuted version of the x shape info for setting up the tad problem
    __shared__ Nd4jLong tadLength;
    __shared__ Nd4jLong tadEWS;
    __shared__ int numTads;
    __shared__ Nd4jLong yEWS;
    __shared__ Nd4jLong zEWS;

    // thread 0 computes the shared loop parameters once per block
    if (threadIdx.x == 0) {
        tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
        tadEWS = shape::elementWiseStride(tadOnlyShapeInfo);
        numTads = shape::length(xShapeInfo) / tadLength;
        yEWS = shape::elementWiseStride(yShapeInfo);
        zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ);
    }
    __syncthreads();

    // per-TAD base pointers, published by thread 0 each iteration
    __shared__ X *rZ;
    __shared__ X *rX;

    for (int r = blockIdx.x; r < numTads; r += gridDim.x) {

        if (threadIdx.x == 0) {
            rZ = z + tadOffsetsZ[r];
            rX = x + tadOffsets[r];
        }
        __syncthreads();

        if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {
            // strided fast path
            for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
                rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
        }
        else {
            // it is expected that x and z tads and y array all have the same length
            for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) {
                auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength);
                auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength);
                auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength);
                rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
            }
        }
    }
}
// Host-side exec() entry points are intentionally empty in this CUDA
// translation unit: on this backend the work goes through
// execBroadcast / execInverseBroadcast (kernel launches) above.
template<typename X>
void BroadcastInt<X>::exec(int opNum,
                           void *x,
                           Nd4jLong *xShapeInfo,
                           void *y,
                           Nd4jLong *yShapeInfo,
                           void *result,
                           Nd4jLong *resultShapeInfo,
                           int *dimension,
                           int dimensionLength,
                           Nd4jLong *tadShapeInfo,
                           Nd4jLong *tadOffset,
                           Nd4jLong *tadShapeInfoZ,
                           Nd4jLong *tadOffsetZ) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
void BroadcastInt<X>::execInverse(int opNum,
                                  void *x,
                                  Nd4jLong *xShapeInfo,
                                  void *y,
                                  Nd4jLong *yShapeInfo,
                                  void *result,
                                  Nd4jLong *resultShapeInfo,
                                  int *dimension,
                                  int dimensionLength,
                                  Nd4jLong *tadShapeInfo,
                                  Nd4jLong *tadOffset,
                                  Nd4jLong *tadShapeInfoZ,
                                  Nd4jLong *tadOffsetZ) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
template<typename OpType>
void BroadcastInt<X>::exec(void *x,
                           Nd4jLong *xShapeInfo,
                           void *y,
                           Nd4jLong *yShapeInfo,
                           void *result,
                           Nd4jLong *resultShapeInfo,
                           int *dimension,
                           int dimensionLength,
                           Nd4jLong *tadShapeInfo,
                           Nd4jLong *tadOffset,
                           Nd4jLong *tadShapeInfoZ,
                           Nd4jLong *tadOffsetZ) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
template<typename OpType>
void BroadcastInt<X>::execInverse(void *x,
                                  Nd4jLong *xShapeInfo,
                                  void *y,
                                  Nd4jLong *yShapeInfo,
                                  void *result,
                                  Nd4jLong *resultShapeInfo,
                                  int *dimension,
                                  int dimensionLength,
                                  Nd4jLong *tadShapeInfo,
                                  Nd4jLong *tadOffset,
                                  Nd4jLong *tadShapeInfoZ,
                                  Nd4jLong *tadOffsetZ) {
    // no-op on CUDA backend
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES);
}
}

View File

@ -0,0 +1,173 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018
#ifndef PAIRWISE_INT_CU
#define PAIRWISE_INT_CU
#include "../pairwise_int.h"
using namespace simdOps;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// CUDA kernel for element-wise pairwise integer ops: z[i] = op(x[i], y[i]).
// Three traversal strategies, chosen once per launch:
//   1) all EWS usable and orders match  -> strided grid-stride loop
//   2) in-place (vx == vz)              -> reuse x's offset for z
//   3) general                          -> full index->offset translation
template <typename X, typename OpType>
__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo,
                                            void *vy, Nd4jLong *yShapeInfo,
                                            void *vz, Nd4jLong *zShapeInfo,
                                            void *vextraParams) {

    auto x = reinterpret_cast<X*>(vx);
    auto y = reinterpret_cast<X*>(vy);
    auto z = reinterpret_cast<X*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);

    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // loop parameters computed once per block by thread 0
    __shared__ int xEws;
    __shared__ int yEws;
    __shared__ int zEws;
    __shared__ char xOrder;
    __shared__ char yOrder;
    __shared__ char zOrder;
    __shared__ Nd4jLong len;

    if (threadIdx.x == 0) {
        xEws = shape::elementWiseStride(xShapeInfo);
        yEws = shape::elementWiseStride(yShapeInfo);
        zEws = shape::elementWiseStride(zShapeInfo);
        xOrder = shape::order(xShapeInfo);
        yOrder = shape::order(yShapeInfo);
        zOrder = shape::order(zShapeInfo);
        len = shape::length(xShapeInfo);
    }
    __syncthreads();

    if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) {
        // strided fast path
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams);
        }
    }
    else if (vx == vz) {
        // in-place: x and z share a buffer, so they share offsets too
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
            auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
            z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
        }
    }
    else {
        // general case: independent offsets for x, y and z
        for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) {
            auto xOffset = shape::getIndexOffset(i, xShapeInfo, len);
            auto yOffset = shape::getIndexOffset(i, yShapeInfo, len);
            auto zOffset = shape::getIndexOffset(i, zShapeInfo, len);
            z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams);
        }
    }
}
namespace functions {
namespace pairwise_transforms {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Host-side launcher: starts pairwiseSimpleShaped for a concrete op type
// on the given stream with the given launch configuration.
template<typename X>
template<typename OpType>
void _CUDA_H PairWiseIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                                                         void *vx, Nd4jLong *xShapeInfo,
                                                         void *vy, Nd4jLong *yShapeInfo,
                                                         void *vz, Nd4jLong *zShapeInfo,
                                                         void *vextraParams){
    pairwiseSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams);
}
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
// Op-number dispatcher: picks the PAIRWISE_INT_OPS entry matching opNum and
// forwards to intermediateShaped<OpType>, which launches the kernel.
// Fix: removed the unused local `xType` (dead call to DataTypeUtils::fromT<X>()
// whose result was never read by the dispatch macro).
template<typename X>
void PairWiseIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) {
    DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS);
}
// Host-side exec() entry points are intentionally empty in this CUDA
// translation unit: on this backend the work goes through
// executeCudaShaped (kernel launch) above.
template<typename X>
void PairWiseIntTransform<X>::exec(
        const int opNum,
        void *dx,
        Nd4jLong *xShapeBuffer,
        void *y,
        Nd4jLong *yShapeBuffer,
        void *result,
        Nd4jLong *resultShapeBuffer,
        void *extraParams) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
void PairWiseIntTransform<X>::exec(
        const int opNum,
        void *dx,
        Nd4jLong xStride,
        void *y,
        Nd4jLong yStride,
        void *result,
        Nd4jLong resultStride,
        void *extraParams,
        Nd4jLong n) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
template<typename OpType>
void PairWiseIntTransform<X>::exec(
        void *vx,
        Nd4jLong* xShapeBuffer,
        void *vy,
        Nd4jLong* yShapeBuffer,
        void *vresult,
        Nd4jLong* resultShapeBuffer,
        void *vextraParams) {
    // no-op on CUDA backend
}

// Intentionally empty — see note above.
template<typename X>
template<typename OpType>
void PairWiseIntTransform<X>::exec(void *vx,
                                   Nd4jLong xStride,
                                   void *vy,
                                   Nd4jLong yStride,
                                   void *vresult,
                                   Nd4jLong resultStride,
                                   void *vextraParams,
                                   const Nd4jLong n) {
    // no-op on CUDA backend
}
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES);
}
}
#endif // PAIRWISE_INT_CU

View File

@ -0,0 +1,269 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018
// @author raver119@gmail.com
//
#include "../scalar_int.h"
#include <op_boilerplate.h>
#include <types/types.h>
#include "../legacy_ops.h"
using namespace simdOps;
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Kernel: scalar op along dimension — one scalar per TAD. Forwards to the
// TAD-aware ScalarIntTransform<X>::transformCuda overload.
template <typename X, typename OpType>
__global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo,
                                     void *extraParams,
                                     void *z, Nd4jLong *zShapeInfo,
                                     void *scalars,
                                     int *dimension, int dimensionLength,
                                     Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                     Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    functions::scalar::ScalarIntTransform<X>::template transformCuda<OpType>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
}

////////////////////////////////////////////////////////////////////////
// Kernel: simple whole-array scalar op. Note the argument swap in the
// forwarding call: this kernel's `y` (launched with vscalar — see
// intermediateShaped) becomes transformCuda's first parameter (the scalar),
// while `x` becomes the input array.
template <typename X, typename OpType>
__global__ void scalarSimpleShaped(void* x, void *y, Nd4jLong *xShapeInfo, void *params, void *z, Nd4jLong *zShapeInfo, int *allocationBuffer) {
    functions::scalar::ScalarIntTransform<X>::template transformCuda<OpType>(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer);
}
// *********************************************************************//
// *********************************************************************//
namespace functions {
namespace scalar {
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Device-side scalar transform: z[i] = OpType::op(y[i], scalar, params).
// Fast-paths to the EWS overload when both buffers have usable element-wise
// strides and matching orders; otherwise walks both arrays via
// index->offset translation in a grid-stride loop.
// Fix: removed six unused locals (yRank/yShape/yStride/zRank/zShape/zStride)
// that were computed but never read — dead work and wasted registers in
// device code.
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(void* vscalar,
                                                     void *vy, Nd4jLong *yShapeInfo,
                                                     void *vparams,
                                                     void *vz, Nd4jLong *zShapeInfo,
                                                     int *allocationBuffer) {
    auto scalar = reinterpret_cast<X*>(vscalar)[0];   // single scalar value
    auto y      = reinterpret_cast<X*>(vy);
    auto params = reinterpret_cast<X*>(vparams);
    auto z      = reinterpret_cast<X*>(vz);

    auto yEWS = shape::elementWiseStride(yShapeInfo);
    auto zEWS = shape::elementWiseStride(zShapeInfo);

    int totalThreads = gridDim.x * blockDim.x;
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // length computed once per block by thread 0
    __shared__ int len;
    if(threadIdx.x == 0)
        len = shape::length(yShapeInfo);
    __syncthreads();

    if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) {
        // strided fast path handled by the EWS overload
        transformCuda<OpType>(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer);
    }
    else {
        // general case: per-element offset translation
        for (Nd4jLong i = tid; i < len; i+= totalThreads)
            z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params);
    }
}
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Device-side strided scalar transform: z[i*zEWS] = op(y[i*yEWS], scalar).
// `vx` carries the scalar (only element [0] is read). Grid-stride loop;
// a dedicated branch avoids the stride multiplication when both EWS are 1.
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(Nd4jLong len,
                                                     void* vx,
                                                     void *vy, Nd4jLong yEWS,
                                                     void *vparams,
                                                     void *vz, Nd4jLong zEWS,
                                                     int *allocationBuffer) {

    auto x = reinterpret_cast<X*>(vx)[0];   // the scalar value
    auto y = reinterpret_cast<X*>(vy);
    auto z = reinterpret_cast<X*>(vz);
    auto params = reinterpret_cast<X*>(vparams);

    int totalThreads = gridDim.x * blockDim.x;
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    Nd4jLong i = tid;
    if(yEWS == 1 && zEWS == 1) {
        // contiguous case
        for (; i < len; i += totalThreads)
            z[i] = OpType::op(y[i], x, params);
    }
    else {
        // strided case
        for (; i < len; i += totalThreads)
            z[i * zEWS] = OpType::op(y[i * yEWS], x, params);
    }
}
////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////
// Device-side scalar-along-dimension transform: for each TAD r of x,
// z_tad[r][f] = OpType::op(x_tad[r][f], scalars[r]) — one scalar per TAD.
// Blocks iterate over TADs; threads iterate within a TAD.
template<typename X>
template<typename OpType>
__device__ void ScalarIntTransform<X>::transformCuda(void *vx, Nd4jLong *xShapeInfo,
                                                     void *vextraParams,
                                                     void *vz, Nd4jLong *zShapeInfo,
                                                     void *vscalars,
                                                     int *dimension, int dimensionLength,
                                                     Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                                     Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    auto x = reinterpret_cast<X*>(vx);
    auto scalars = reinterpret_cast<X*>(vscalars);
    auto z = reinterpret_cast<X*>(vz);
    auto extraParams = reinterpret_cast<X*>(vextraParams);

    // reuse the x-side TAD description for z when the caller didn't pass one
    if (tadShapeInfoZ == nullptr) {
        tadShapeInfoZ = tadShapeInfo;
        tadOffsetsZ = tadOffsets;
    }

    // tad preparation
    auto tadEws = shape::elementWiseStride(tadShapeInfo);
    auto zEws = shape::elementWiseStride(tadShapeInfoZ);
    auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength);
    auto numTads =shape::length(xShapeInfo) / tadLength;

    if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) {

        // main loop, rolling over tads
        for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
            X *oZ = z + tadOffsetsZ[r];
            X *oX = x + tadOffsets[r];

            auto s = scalars[r];

            for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
                oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams);
        }
    } else {
        // main loop, rolling over tads — offsets translated per element
        for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
            X *oZ = z + tadOffsetsZ[r];
            X *oX = x + tadOffsets[r];

            auto s = scalars[r];

            for (int f = threadIdx.x; f < tadLength; f += blockDim.x)
                oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams);
        }
    }
}
////////////////////////////////////////////////////////////////////////
template<typename X>
template <typename OpType>
// Host-side launcher: starts the scalarAlongDimension kernel for a concrete
// op type on the given stream with the given launch configuration.
_CUDA_H void ScalarIntTransform<X>::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream,
                                                               void *x, Nd4jLong *xShapeInfo,
                                                               void *z, Nd4jLong *zShapeInfo,
                                                               void *scalars,
                                                               void *extraParams,
                                                               int *dimension, int dimensionLength,
                                                               Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                                                               Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    scalarAlongDimension<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ);
}

////////////////////////////////////////////////////////////////////////
// Host-side launcher: starts the scalarSimpleShaped kernel. Note the
// launch passes (vx, vscalar, ...) — the kernel swaps them back when
// forwarding to transformCuda.
template<typename X>
template<typename OpType>
void _CUDA_H ScalarIntTransform<X>::intermediateShaped(dim3& launchDims, cudaStream_t *stream,
                                                       void *vx, Nd4jLong *xShapeInfo,
                                                       void *vz, Nd4jLong *zShapeInfo,
                                                       void* vscalar,
                                                       void *vextraParams, int *allocPointer){
    scalarSimpleShaped<X, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer);
}
////////////////////////////////////////////////////////////////////////
// Public entry point: dispatches a whole-array scalar integer op by runtime
// op number (one of SCALAR_INT_OPS) to the matching intermediateShaped<OpType>.
template<typename X>
void ScalarIntTransform<X>::executeCudaShaped(dim3& launchDims, cudaStream_t *stream,
                                              int opNum,
                                              void *vx, Nd4jLong *xShapeInfo,
                                              void *vz, Nd4jLong *zShapeInfo,
                                              void* vscalar,
                                              void *vextraParams) {
    // debug trace only; "H14" tags this dispatch point in verbose logs
    if (nd4j::Environment::getInstance()->isDebugAndVerbose())
        printf("H14 opNum:[%i]\n", opNum);

    // nullptr here is the allocPointer argument of intermediateShaped
    DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS);
}
////////////////////////////////////////////////////////////////////////
// Public entry point: dispatches the per-TAD scalar integer op by runtime
// op number (one of SCALAR_INT_OPS) to intermediateAlongDimension<OpType>.
template<typename X>
void ScalarIntTransform<X>::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS);
}

// explicit instantiation for every supported integer element type
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES);
// CPU-style along-dimension overload: intentionally a no-op on the CUDA
// backend (execution goes through executeCudaAlongDimension instead).
// FIX: the class name was spelled `ScalarIntTransform<X,>` — a stray comma
// that makes the template-id invalid C++; corrected to `ScalarIntTransform<X>`.
template<typename X>
template <typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
    // no-op: CUDA path does not use the host transform
}
// The remaining host-side `transform` overloads are deliberate no-ops on the
// CUDA backend: all real execution is routed through executeCudaShaped /
// executeCudaAlongDimension above. The empty bodies exist only so the shared
// class interface links on this backend.

// along-dimension, runtime op number — no-op on CUDA
template<typename X>
void ScalarIntTransform<X>::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
}

// whole-array, runtime op number — no-op on CUDA
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}

// strided, runtime op number — no-op on CUDA
template<typename X>
void ScalarIntTransform<X>::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}

// whole-array, compile-time op — no-op on CUDA
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) {
}

// strided, compile-time op — no-op on CUDA
template<typename X>
template<typename OpType>
void ScalarIntTransform<X>::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) {
}
}
}

View File

@ -29,6 +29,16 @@
(4, aggregateOps::CBOW) ,\
(5, aggregateOps::GEMM)
// Integer-only broadcast ops: (op number, simdOps class) pairs consumed by
// the DISPATCH_BY_OPNUM / BUILD_ENUMERATION machinery. Numbering must stay in
// sync with SCALAR_INT_OPS and PAIRWISE_INT_OPS.
#define BROADCAST_INT_OPS \
(0, ShiftLeft), \
(1, ShiftRight), \
(2, CyclicShiftLeft), \
(3, CyclicShiftRight), \
(4, IntAnd), \
(5, IntOr), \
(6, IntXor)
#define BROADCAST_BOOL_OPS \
(0, EqualTo),\
(1, GreaterThan),\
@ -171,6 +181,14 @@
(0, SummaryStatsVariance), \
(1, SummaryStatsStandardDeviation)
// Integer-only scalar ops: same numbering as BROADCAST_INT_OPS /
// PAIRWISE_INT_OPS so one op id maps to the same operation in all three
// execution modes (scalar / pairwise / broadcast).
#define SCALAR_INT_OPS \
(0, ShiftLeft), \
(1, ShiftRight), \
(2, CyclicShiftLeft), \
(3, CyclicShiftRight), \
(4, IntAnd), \
(5, IntOr), \
(6, IntXor)
#define SCALAR_BOOL_OPS \
(0, EqualTo),\
@ -300,6 +318,15 @@
(13, ExponentialDistribution),\
(14, ExponentialDistributionInv)
// Integer-only pairwise ops: must keep the same op numbering as
// SCALAR_INT_OPS and BROADCAST_INT_OPS (BroadcastIntOpsTuple relies on the
// three enums describing the same operation).
#define PAIRWISE_INT_OPS \
(0, ShiftLeft), \
(1, ShiftRight), \
(2, CyclicShiftLeft), \
(3, CyclicShiftRight), \
(4, IntAnd), \
(5, IntOr), \
(6, IntXor)
#define PAIRWISE_BOOL_OPS \
(0, EqualTo),\
(1, GreaterThan),\

View File

@ -0,0 +1,119 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
/*
* pairwise_transform.h
*
* Created on: Dec 28, 2015
* Author: agibsonccc
*/
#ifndef PAIRWISE_INT_H_
#define PAIRWISE_INT_H_
#ifdef _OPENMP
#include <omp.h>
#endif
#include <templatemath.h>
#include <helpers/shape.h>
#include <pairwise_util.h>
#include <dll.h>
#include <stdio.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include <helpers/DebugHelper.h>
#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#ifndef _OPENMP
#define omp_get_thread_num() 0
#define omp_get_max_threads() 1
#endif
#include "legacy_ops.h"
using namespace simdOps;
namespace functions {
namespace pairwise_transforms {
/**
 * Element-wise transforms over two integer arrays of the same length:
 * result[i] = OpType::op(x[i], y[i], extraParams).
 * X is the (integer) element type shared by both inputs and the result.
 */
template<typename X>
class PairWiseIntTransform {
    public:

#ifdef __CUDACC__
        // GPU path: launch the pairwise kernel for a compile-time op
        template <typename OpType>
        static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams);

        // GPU path: runtime op-number dispatch (PAIRWISE_INT_OPS) to intermediateShaped
        static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams);

#endif
    public:

        // CPU path: shape-info-driven execution, op chosen at runtime by opNum
        static void exec(
				const int opNum,
				void *dx,
				Nd4jLong *xShapeBuffer,
				void *y,
				Nd4jLong *yShapeBuffer,
				void *result,
				Nd4jLong *resultShapeBuffer,
				void *extraParams);

        // CPU path: strided execution over n elements, op chosen at runtime
        static void exec(
				const int opNum,
				void *dx,
				Nd4jLong xStride,
				void *y,
				Nd4jLong yStride,
				void *result,
				Nd4jLong resultStride,
				void *extraParams,
				Nd4jLong n);

        // CPU path: shape-info-driven execution for a compile-time op
        template<typename OpType>
		static void exec(
                void *vx,
                Nd4jLong* xShapeBuffer,
                void *vy,
                Nd4jLong* yShapeBuffer,
                void *vresult,
                Nd4jLong* resultShapeBuffer,
                void *vextraParams);

        // CPU path: strided execution for a compile-time op
        template<typename OpType>
        static void exec(void *vx,
                         Nd4jLong xStride,
                         void *vy,
                         Nd4jLong yStride,
                         void *vresult,
                         Nd4jLong resultStride,
                         void *vextraParams,
                         const Nd4jLong n);
};
}
}
#endif /* PAIRWISE_INT_H_ */

View File

@ -0,0 +1,142 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
/*
* scalar.h
*
* Created on: Dec 28, 2015
* Author: agibsonccc
*/
#ifndef SCALAR_INT_H_
#define SCALAR_INT_H_
#include <dll.h>
#ifdef __JNI__
#include <jni.h>
#endif
#include <templatemath.h>
#include <ops/ops.h>
#include <op_boilerplate.h>
#include "helpers/logger.h"
#include <OmpLaunchHelper.h>
#include <helpers/DebugHelper.h>
#ifdef __CUDACC__
#include <cuda.h>
#include <cuda_runtime.h>
#include <types/float16.h>
#endif
#include "legacy_ops.h"
namespace functions {
namespace scalar {
/**
 * Applies a scalar integer operation to an array:
 * result[i] = OpType::op(x[i], scalar, extraParams),
 * plus an "along dimension" variant where each TAD gets its own scalar.
 * X is the (integer) element type of input, scalar and result.
 */
template<typename X>
class ScalarIntTransform {

    public:

#ifdef __CUDACC__
        // device kernel body: shape-info-driven scalar op
        template<typename OpType>
        __device__
        static void transformCuda(void* scalar, void *vy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *resultShapeInfo, int *allocationBuffer);

        // device kernel body: strided scalar op over n elements
        template<typename OpType>
        __device__
        static void transformCuda(Nd4jLong n, void* vx, void *vy, Nd4jLong yEWS, void *vparams, void *vz, Nd4jLong zEWS, int *allocationBuffer);

        // device kernel body: one scalar per TAD along the given dimensions
        template<typename OpType>
        __device__
        static void transformCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vscalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

        // host launcher for the along-dimension kernel (compile-time op)
        template <typename OpType>
        __host__
        static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

        // host launcher for the whole-array kernel (compile-time op)
        template <typename OpType>
        __host__
        static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void* vscalar, void *vextraParams, int *allocPointer);

        // host entry: runtime opNum dispatch (SCALAR_INT_OPS), whole array
        __host__
        static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void* scalar, void *extraParams);

        // host entry: runtime opNum dispatch, one scalar per TAD
        __host__
        static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

/*
#include "cuda/scalar_temp.cu"
*/
#endif

        // CPU: one scalar per TAD, compile-time op
        template <typename OpType>
        static void transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

        // CPU: one scalar per TAD, runtime op number
        static void transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

        // CPU: whole array, runtime op number
        static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);

        // CPU: strided, runtime op number
        static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n);

        /*
         * ScalarOp along dimension
         */

        /**
         * CPU implementation of scalar operation
         * @param x the input
         * @param xStride the stride for the input
         * @param result the result buffer
         * @param resultStride the stride for the result
         * @param scalar the scalar to apply
         * @param extraParams the extra parameters where
         * necessary
         * @param n the number of elements to loop over
         */
        template<typename OpType>
        static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams);

        /**
         * CPU implementation of scalar operation
         * @param x the input
         * @param xStride the stride for the input
         * @param result the result buffer
         * @param resultStride the stride for the result
         * @param scalar the scalar to apply
         * @param extraParams the extra parameters where
         * necessary
         * @param n the number of elements to loop over
         */
        template<typename OpType>
        static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n);
};
}
}
#endif /* SCALAR_INT_H_ */

View File

@ -63,6 +63,10 @@ namespace nd4j {
enum BoolOps {
BUILD_ENUMERATION(PAIRWISE_BOOL_OPS)
};
enum IntOps {
BUILD_ENUMERATION(PAIRWISE_INT_OPS)
};
}
namespace scalar {
@ -73,6 +77,10 @@ namespace nd4j {
enum BoolOps {
BUILD_ENUMERATION(SCALAR_BOOL_OPS)
};
enum IntOps {
BUILD_ENUMERATION(SCALAR_INT_OPS)
};
}
namespace reduce {
@ -113,6 +121,10 @@ namespace nd4j {
enum BoolOps {
BUILD_ENUMERATION(BROADCAST_BOOL_OPS)
};
enum IntOps {
BUILD_ENUMERATION(BROADCAST_INT_OPS)
};
}
namespace variance {

View File

@ -0,0 +1,49 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H
#define DEV_TESTS_BROADCASTINTOPSTUPLE_H
#include <op_enums.h>
namespace nd4j {
    /**
     * Bundles the three enum ids (scalar / pairwise / broadcast) that
     * describe one and the same integer operation, so broadcastable ops can
     * pick the right flavor at runtime.
     */
    class BroadcastIntOpsTuple {
    private:

    public:
        nd4j::scalar::IntOps s;      // scalar flavor
        nd4j::pairwise::IntOps p;    // pairwise flavor
        nd4j::broadcast::IntOps b;   // broadcast flavor

        BroadcastIntOpsTuple() = default;
        ~BroadcastIntOpsTuple() = default;

        // initialize all three flavors directly via the member-init list
        BroadcastIntOpsTuple(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast)
                : s(scalar), p(pairwise), b(broadcast) {
        }

        // factory helper; defined in BroadcastIntOpsTuple.cpp
        static BroadcastIntOpsTuple custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast);
    };
}
#endif //DEV_TESTS_BROADCASTINTOPSTUPLE_H

View File

@ -27,22 +27,14 @@
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing");
BROADCAST_CHECK_EMPTY(x,y,z);
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift);
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type")
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false);
return Status::OK();
}

View File

@ -27,22 +27,14 @@
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing");
BROADCAST_CHECK_EMPTY(x,y,z);
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift);
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type")
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false);
return Status::OK();
}

View File

@ -27,22 +27,14 @@
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing");
BROADCAST_CHECK_EMPTY(x,y,z);
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type")
helpers::rshift_bits(block.launchContext(), *input, *output, shift);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false);
return Status::OK();
}

View File

@ -27,22 +27,14 @@
namespace nd4j {
namespace ops {
CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing");
BROADCAST_CHECK_EMPTY(x,y,z);
uint32_t shift = 0;
if (block.width() > 1) {
shift = INPUT_VARIABLE(1)->e<uint32_t>(0);
} else if (block.numI() > 0) {
shift = INT_ARG(0);
};
REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type")
helpers::shift_bits(block.launchContext(), *input, *output, shift);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false);
return Status::OK();
}

View File

@ -45,7 +45,7 @@ namespace nd4j {
* @tparam T
*/
#if NOT_EXCLUDED(OP_shift_bits)
DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2);
DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0);
#endif
/**
@ -56,7 +56,7 @@ namespace nd4j {
* @tparam T
*/
#if NOT_EXCLUDED(OP_rshift_bits)
DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2);
DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0);
#endif
/**
@ -67,7 +67,7 @@ namespace nd4j {
* @tparam T
*/
#if NOT_EXCLUDED(OP_cyclic_shift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2);
DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0);
#endif
/**
@ -78,7 +78,7 @@ namespace nd4j {
* @tparam T
*/
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
#endif
/**

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/BroadcastIntOpsTuple.h>
namespace nd4j {
    // Factory: wrap the three enum ids into a tuple. Constructing directly in
    // the return statement lets the compiler elide the copy (RVO).
    BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) {
        return BroadcastIntOpsTuple(scalar, pairwise, broadcast);
    }
}

View File

@ -660,6 +660,98 @@ namespace simdOps {
}
};
template <typename X>
class IntOr {
public:
op_def static X op(X d1, X d2) {
return d2 | d1;
}
op_def static X op(X d1, X d2, X *params) {
return op(d1, d2);
}
};
template <typename X>
class IntAnd {
public:
op_def static X op(X d1, X d2) {
return d2 & d1;
}
op_def static X op(X d1, X d2, X *params) {
return op(d1, d2);
}
};
template <typename X>
class IntXor {
public:
op_def static X op(X d1, X d2) {
return d2 ^ d1;
}
op_def static X op(X d1, X d2, X *params) {
return op(d1, d2);
}
};
	// Bitwise left shift: d1 << d2.
	// NOTE(review): a shift count that is negative or >= the bit width of X is
	// undefined behavior in C++ — callers are assumed to validate d2; confirm.
	template <typename X>
	class ShiftLeft {
	public:
		op_def static X op(X d1, X d2) {
			return d1 << d2;
		}

		// params is unused; overload exists to match the generic op signature
		op_def static X op(X d1, X d2, X *params) {
			return op(d1, d2);
		}
	};
	// Bitwise right shift: d1 >> d2. For signed X this is an arithmetic
	// (sign-propagating) shift on common compilers.
	// NOTE(review): a shift count that is negative or >= the bit width of X is
	// undefined behavior in C++ — callers are assumed to validate d2; confirm.
	template <typename X>
	class ShiftRight {
	public:
		op_def static X op(X d1, X d2) {
			return d1 >> d2;
		}

		// params is unused; overload exists to match the generic op signature
		op_def static X op(X d1, X d2, X *params) {
			return op(d1, d2);
		}
	};
template <typename X>
class CyclicShiftLeft {
public:
op_def static X op(X d1, X d2) {
return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2);
}
op_def static X op(X d1, X d2, X *params) {
return op(d1, d2);
}
};
template <typename X>
class CyclicShiftRight {
public:
op_def static X op(X d1, X d2) {
return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2);
}
op_def static X op(X d1, X d2, X *params) {
return op(d1, d2);
}
};
template <typename X, typename Z>
class Or {
public:

View File

@ -623,12 +623,13 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) {
TEST_F(DeclarableOpsTests13, shift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto y = NDArrayFactory::create<int>(4);
auto e = x.ulike();
x.assign(32);
e.assign(512);
nd4j::ops::shift_bits op;
auto result = op.execute({&x}, {}, {4});
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
@ -640,12 +641,13 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) {
TEST_F(DeclarableOpsTests13, rshift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto y = NDArrayFactory::create<int>(4);
auto e = x.ulike();
x.assign(512);
e.assign(32);
nd4j::ops::rshift_bits op;
auto result = op.execute({&x}, {}, {4});
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
@ -657,12 +659,13 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) {
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto y = NDArrayFactory::create<int>(4);
auto e = x.ulike();
x.assign(32);
e.assign(512);
nd4j::ops::cyclic_shift_bits op;
auto result = op.execute({&x}, {}, {4});
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
@ -674,12 +677,107 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) {
TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) {
auto x = NDArrayFactory::create<int>('c', {5});
auto y = NDArrayFactory::create<int>(4);
auto e = x.ulike();
x.assign(512);
e.assign(32);
nd4j::ops::cyclic_rshift_bits op;
auto result = op.execute({&x}, {}, {4});
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}
// shift_bits with a same-shape shift array: 32 << 4 == 512 element-wise
TEST_F(DeclarableOpsTests13, shift_bits_2) {
    auto x = NDArrayFactory::create<int>('c', {5});
    auto y = NDArrayFactory::create<int>('c', {5});
    auto e = x.ulike();
    x.assign(32);
    y.assign(4);
    e.assign(512);

    nd4j::ops::shift_bits op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    ASSERT_EQ(e, *z);

    delete result;
}

// rshift_bits with a same-shape shift array: 512 >> 4 == 32 element-wise
TEST_F(DeclarableOpsTests13, rshift_bits_2) {
    auto x = NDArrayFactory::create<int>('c', {5});
    auto y = NDArrayFactory::create<int>('c', {5});
    auto e = x.ulike();
    x.assign(512);
    y.assign(4);
    e.assign(32);

    nd4j::ops::rshift_bits op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    ASSERT_EQ(e, *z);

    delete result;
}

// cyclic_shift_bits with a same-shape shift array: rotl(32, 4) == 512 for int
TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) {
    auto x = NDArrayFactory::create<int>('c', {5});
    auto y = NDArrayFactory::create<int>('c', {5});
    auto e = x.ulike();
    x.assign(32);
    y.assign(4);
    e.assign(512);

    nd4j::ops::cyclic_shift_bits op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    ASSERT_EQ(e, *z);

    delete result;
}

// cyclic_rshift_bits with a same-shape shift array: rotr(512, 4) == 32 for int
TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) {
    auto x = NDArrayFactory::create<int>('c', {5});
    auto y = NDArrayFactory::create<int>('c', {5});
    auto e = x.ulike();
    x.assign(512);
    y.assign(4);
    e.assign(32);

    nd4j::ops::cyclic_rshift_bits op;
    auto result = op.execute({&x, &y}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    ASSERT_EQ(e, *z);

    delete result;
}
TEST_F(DeclarableOpsTests13, shift_bits_3) {
auto x = NDArrayFactory::create<int>('c', {5, 5});
auto y = NDArrayFactory::create<int>('c', {1, 5});
auto e = x.ulike();
x.assign(32);
y.assign(4);
e.assign(512);
nd4j::ops::shift_bits op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);

View File

@ -1234,19 +1234,19 @@ public class DifferentialFunctionFactory {
return new Xor(sameDiff(), ix, iy).outputVariable();
}
public SDVariable shift(SDVariable ix, int shift) {
public SDVariable shift(SDVariable ix, SDVariable shift) {
return new ShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rshift(SDVariable ix, int shift) {
public SDVariable rshift(SDVariable ix, SDVariable shift) {
return new RShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotl(SDVariable ix, int shift) {
public SDVariable rotl(SDVariable ix, SDVariable shift) {
return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable();
}
public SDVariable rotr(SDVariable ix, int shift) {
public SDVariable rotr(SDVariable ix, SDVariable shift) {
return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable();
}

View File

@ -2428,7 +2428,7 @@ public class SDMath extends SDOps {
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShift(String name, SDVariable x, int shift) {
public SDVariable bitShift(String name, SDVariable x, SDVariable shift) {
validateInteger("shift_bits", x);
SDVariable result = f().shift(x, shift);
return updateVariableNameAndReference(result, name);
@ -2441,7 +2441,7 @@ public class SDMath extends SDOps {
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitShiftRight(String name, SDVariable x, int shift) {
public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) {
validateInteger("rshift_bits", x);
SDVariable result = f().rshift(x, shift);
return updateVariableNameAndReference(result, name);
@ -2454,7 +2454,7 @@ public class SDMath extends SDOps {
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotl(String name, SDVariable x, int shift) {
public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) {
validateInteger("cyclic_shift_bits", x);
SDVariable result = f().rotl(x, shift);
return updateVariableNameAndReference(result, name);
@ -2467,7 +2467,7 @@ public class SDMath extends SDOps {
* @param x Input 1
* @return Output SDVariable with shifted bits
*/
public SDVariable bitRotr(String name, SDVariable x, int shift) {
public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) {
validateInteger("cyclic_rshift_bits", x);
SDVariable result = f().rotr(x, shift);
return updateVariableNameAndReference(result, name);

View File

@ -34,18 +34,16 @@ import java.util.List;
*/
public class CyclicRShiftBits extends BaseDynamicTransformOp {
public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) {
super(sameDiff, new SDVariable[] {x, shift} ,false);
}
public CyclicRShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
public CyclicRShiftBits(INDArray input, INDArray shift, INDArray output) {
super(new INDArray[]{input, shift}, new INDArray[]{output});
}
public CyclicRShiftBits(INDArray input, int shift) {
this(input, shift,null);
public CyclicRShiftBits(INDArray input, INDArray shift) {
this(input, shift,input.ulike());
}
public CyclicRShiftBits() {}

View File

@ -34,18 +34,16 @@ import java.util.List;
*/
public class CyclicShiftBits extends BaseDynamicTransformOp {
public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
public CyclicShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) {
super(sameDiff, new SDVariable[] {x, shift} ,false);
}
public CyclicShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
public CyclicShiftBits(INDArray input, INDArray shift, INDArray output) {
super(new INDArray[]{input, shift}, new INDArray[]{output});
}
public CyclicShiftBits(INDArray input, int shift) {
this(input, shift,null);
public CyclicShiftBits(INDArray input, INDArray shift) {
this(input, shift,input.ulike());
}
public CyclicShiftBits() {}

View File

@ -34,18 +34,16 @@ import java.util.List;
*/
public class RShiftBits extends BaseDynamicTransformOp {
public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
public RShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public RShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
public RShiftBits(INDArray input, INDArray shift, INDArray output) {
super(new INDArray[]{input, shift}, new INDArray[]{output});
}
public RShiftBits(INDArray input, int shift) {
this(input, shift,null);
public RShiftBits(INDArray input, INDArray shift) {
this(input, shift,input.ulike());
}
public RShiftBits() {}

View File

@ -34,18 +34,16 @@ import java.util.List;
*/
public class ShiftBits extends BaseDynamicTransformOp {
public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) {
super(sameDiff, new SDVariable[] {x} ,false);
this.addIArgument(shift);
public ShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y} ,false);
}
public ShiftBits(INDArray input, int shift, INDArray output) {
super(new INDArray[]{input}, new INDArray[]{output});
this.addIArgument(shift);
public ShiftBits(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public ShiftBits(INDArray input, int shift) {
this(input, shift,null);
public ShiftBits(INDArray x, INDArray y) {
this(x, y,x.ulike());
}
public ShiftBits() {}

View File

@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <op_enums.h>
// #include <ops/BroadcastOpsTuple.h>
// #include <ops/BroadcastBoolOpsTuple.h>
// #include <ops/BroadcastIntOpsTuple.h>
// #include <array/ExtraArguments.h>
// #include <Status.h>
// #include <ShapeDescriptor.h>

View File

@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
// #include <op_enums.h>
// #include <ops/BroadcastOpsTuple.h>
// #include <ops/BroadcastBoolOpsTuple.h>
// #include <ops/BroadcastIntOpsTuple.h>
// #include <array/ExtraArguments.h>
// #include <Status.h>
// #include <ShapeDescriptor.h>
@ -16496,15 +16497,18 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* creates identity 2D matrix or batch of identical 2D identity matrices
*
*
* Input array:
* provide some array - in any case operation simply neglects it
*
*
* Input float argument (if passed):
* TArgs[0] - type of elements of output array, default value is 5 (float)
*
* Input integer arguments:
* IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order
* IArgs[1] - the number of rows in output inner-most 2D identity matrix
* IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
* IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape
*/
// #if NOT_EXCLUDED(OP_eye)
@Namespace("nd4j::ops") public static class eye extends DeclarableCustomOp {
@ -16598,10 +16602,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* clip a list of given tensors with given average norm when needed
*
*
* Input:
* a list of tensors (at least one)
*
*
* Input floating point argument:
* clip_norm - a value that used as threshold value and norm to be used
*
@ -16749,12 +16753,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* returns histogram (as 1D array) with fixed bins width
*
*
* Input arrays:
* - input array with elements to be binned into output histogram
* - input array with elements to be binned into output histogram
* - range array with first element being bottom limit and second element being top limit of histogram,
please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1]
*
*
* Input integer arguments:
* nbins (optional) - number of histogram bins, default value is 100
*/
@ -21822,7 +21826,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* \tparam T
*/
// #if NOT_EXCLUDED(OP_shift_bits)
@Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class shift_bits extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public shift_bits(Pointer p) { super(p); }
@ -21835,7 +21839,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public shift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
@ -21847,7 +21850,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* \tparam T
*/
// #if NOT_EXCLUDED(OP_rshift_bits)
@Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class rshift_bits extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public rshift_bits(Pointer p) { super(p); }
@ -21860,7 +21863,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public rshift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
@ -21872,7 +21874,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* \tparam T
*/
// #if NOT_EXCLUDED(OP_cyclic_shift_bits)
@Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class cyclic_shift_bits extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public cyclic_shift_bits(Pointer p) { super(p); }
@ -21885,7 +21887,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public cyclic_shift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
@ -21897,7 +21898,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* \tparam T
*/
// #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
@Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp {
@Namespace("nd4j::ops") public static class cyclic_rshift_bits extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public cyclic_rshift_bits(Pointer p) { super(p); }
@ -21910,6 +21911,30 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public cyclic_rshift_bits() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation returns hamming distance based on bits
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bits_hamming_distance)
// JavaCPP-generated JNI wrapper for the native nd4j "bits_hamming_distance" op
// (per the native declaration above: returns hamming distance based on bits,
// applicable only to integer data types).
@Namespace("nd4j::ops") public static class bits_hamming_distance extends DeclarableCustomOp {
// Loads the native backend library before any native method on this class is invoked.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bits_hamming_distance(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
/** Repositions this pointer within the natively allocated array; see {@link Pointer#position(long)}. */
@Override public bits_hamming_distance position(long position) {
return (bits_hamming_distance)super.position(position);
}
/** Default constructor: allocates a fresh native op instance. */
public bits_hamming_distance() { super((Pointer)null); allocate(); }
private native void allocate();
/** Computes the output shape(s) for the given input shapes and op context (implemented natively). */
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

View File

@ -90,8 +90,13 @@ class SDVariableWrapper {
def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other)
def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other)
def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x)
def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x)
def <<(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, other)
def >>(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable =
sameDiff.math.bitShiftRight(null, thisVariable, other)
def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable =
sameDiff.math.bitShift(null, thisVariable, sameDiff.constant(x))
def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable =
sameDiff.math.bitShiftRight(null, thisVariable, sameDiff.constant(x))
// Overloads for numeric arguments
// Float

View File

@ -188,4 +188,16 @@ class MathTest extends FlatSpec with Matchers {
val w3 = w1 >> 2
w3.eval.toIntVector.head shouldBe 4
}
// Verifies the SDVariable-argument overloads of << and >> on SDVariableWrapper.
"SameDiff" should "provide shifting operations with SDVariable argument" in {
  implicit val sd = SameDiff.create()
  val base = sd.constant(16)
  val amount = sd.constant(2)
  // left shift: 16 << 2 == 64
  val shiftedLeft = base << amount
  shiftedLeft.eval.toIntVector.head shouldBe 64
  // right shift: 16 >> 2 == 4
  val shiftedRight = base >> amount
  shiftedRight.eval.toIntVector.head shouldBe 4
}
}