diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index b60af71d5..731a9cd60 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -103,7 +103,7 @@ public class FastTextTest extends BaseDL4JTest { } @Test - public void testPredict() throws IOException { + public void testPredict() { String text = "I like soccer"; FastText fastText = new FastText(supModelFile); @@ -118,7 +118,7 @@ public class FastTextTest extends BaseDL4JTest { } @Test - public void testPredictProbability() throws IOException { + public void testPredictProbability() { String text = "I like soccer"; FastText fastText = new FastText(supModelFile); diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 3237e5033..3ef3716b3 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -659,6 +660,8 @@ namespace nd4j { void applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + void applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams = nullptr) const; + /** * apply operation which requires broadcasting, broadcast a smaller array (tad) along bigger one (this) * tad - array to broadcast @@ -672,6 +675,9 @@ namespace nd4j { void applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + void applyBroadcast(nd4j::broadcast::IntOps op, const std::vector &dimensions, const NDArray *tad, NDArray *target = nullptr, ExtraArguments *extraArgs = nullptr); + + /** * apply operation which requires broadcasting, broadcast one tensor along another, also this method checks the possibility of broadcasting * other - input array @@ -692,6 +698,9 @@ namespace nd4j { void applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + void applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape = true, ExtraArguments *extraArgs = nullptr) const; + + /** * apply a scalar operation to an array * scalar - input scalar @@ -704,6 +713,9 @@ namespace nd4j { template void applyScalar(nd4j::scalar::BoolOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + template + void applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + /** * apply a scalar operation to an array * scalar - input array which is simple scalar @@ -714,6 +726,7 @@ namespace nd4j { void applyScalarArr(nd4j::scalar::BoolOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; + void applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams = nullptr) const; #if defined(__CUDABLAS__) //&& defined(BUILD_TESTS) template diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 72b029c0b..82427f9b9 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2663,6 +2663,88 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* delete pTarget; } + + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) const { + if (isS()) + throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); + if(target == nullptr || other == nullptr) + throw std::runtime_error("NDArray::applyTrueBroadcast int method: target or other = nullptr !"); + + if (isEmpty() || other->isEmpty()) + return; + + NDArray::prepareSpecialUse({target}, {this, other}); + + if (isScalar()) { + NDArray temp(target->_shapeInfo, dataType(), false, getContext()); + temp.assign(this); + temp.applyPairwiseTransform(op.p, other, target, extraArgs); + return; + } + if (other->isScalar()) { + this->applyScalarArr(op.s, other, target, extraArgs); + return; + } + + const NDArray* min(other); + const NDArray* max(this); + + if(this->rankOf() < other->rankOf()) { + max = other; + min = this; + } + + if(checkTargetShape) { + Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*max, *min, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + if(!shape::equalsSoft(target->_shapeInfo, newShapeInfo) || target->dataType() != this->dataType()) + throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); + if(dataType() != other->dataType()) + throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); + } + + NDArray* pTarget = (max->dataType() == target->dataType()) ? target : new NDArray(target->ordering(), target->getShapeAsVector(), max->dataType(), target->getContext()); + // check whether max array has to be tiled + if(!max->isSameShape(target)) { + // evaluate repeating dimensions for tile operation + std::vector repeatMax(max->rankOf()); + for(int i = 1; i <= max->rankOf(); ++i) + repeatMax[i-1] = (target->_shapeInfo[i] / max->_shapeInfo[i]); + max->tile(repeatMax, *pTarget); + } + else + pTarget->assign(max); + + // check whether min array has to be tiled + std::vector repeatMin(min->rankOf()); + int product = 1; + for(int i = min->rankOf(); i >=1 ; --i) { + repeatMin[i-1] = (target->_shapeInfo[target->rankOf() - min->rankOf() + i] / min->_shapeInfo[i]); + product *= repeatMin[i-1]; + } + + auto pMin = const_cast(min); + if(product != 1 ) + pMin = new NDArray(min->tile(repeatMin)); + + std::vector sameDims = ShapeUtils::getDimsWithSameShape(*target, *pMin); + + if(max == this) + pTarget->applyBroadcast(op.b, sameDims, pMin, target, extraArgs); + else + pMin->applyBroadcast(op.b, sameDims, pTarget, target, extraArgs); + + if(pMin != min) + delete pMin; + if(pTarget != target) + delete pTarget; + } + + + ////////////////////////////////////////////////////////////////////////// NDArray NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const { if (isEmpty() || other.isEmpty()) { @@ -2801,6 +2883,67 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector registerSpecialUse({result}, {this, other}); } + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyBroadcast(nd4j::broadcast::IntOps op, const std::vector& dimensions, const NDArray* other, NDArray* target, ExtraArguments* extraArgs) { + if (!isZ()) + throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!"); + if(isEmpty() || other->isEmpty()) { + if(!target->isEmpty()) + throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !"); + return; + } + + if (dimensions.empty()) + return; + + auto result = target == nullptr ? this : target; + + if (other->lengthOf() == lengthOf() && this->rankOf() == other->rankOf()) { + NDArray::prepareSpecialUse({result}, {this, other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {this, other}); + return; + } + + NDArray *min(nullptr), *max(nullptr); + if((lengthOf() > other->lengthOf()) || (lengthOf() == other->lengthOf() && rankOf() >= other->rankOf())) { + max = this; + min = const_cast(other); + } + else { + max = const_cast(other); + min = this; + } + + if(result->dataType() != dataType()) + throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!"); + if(!result->isSameShape(max)) + throw std::invalid_argument("NDArray::applyBroadcast int method: max and target arrays must have the same shape !"); + if(_dataType != other->_dataType) + throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !"); + + std::vector copy(dimensions); + + if (dimensions.size() > 1) + std::sort(copy.begin(), copy.end()); + + Nd4jLong tadLength = shape::tadLength(max->shapeInfo(), copy.data(), (int) copy.size()); + if (tadLength != min->lengthOf()) + throw std::runtime_error("Tad length mismatch"); + + auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(max->shapeInfo(), copy); + auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(result->shapeInfo(), copy); + + // TODO: eventually we want separate tads here + NDArray::prepareSpecialUse({result}, {this, other}); + if(max == this) + NativeOpExecutioner::execBroadcastInt( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + else + NativeOpExecutioner::execInverseBroadcastInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + registerSpecialUse({result}, {this, other}); + } + ////////////////////////////////////////////////////////////////////////// void NDArray::applyBroadcast(nd4j::broadcast::Ops op, const std::initializer_list dimensions, const NDArray* tadArray, NDArray* target, ExtraArguments* extraArgs) { std::vector vec(dimensions); @@ -3043,6 +3186,22 @@ void NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op, const NDArray * NDArray::registerSpecialUse({target}, {this, other}); } +//////////////////////////////////////////////////////////////////////// + void NDArray::applyPairwiseTransform(nd4j::pairwise::IntOps op, const NDArray *other, NDArray *target, ExtraArguments *extraParams) const{ + if (isS()) + throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!"); + if (other->lengthOf() != target->lengthOf()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched"); + if (!target->isZ()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type"); + if (dataType() != other->dataType()) + throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !"); + + NDArray::prepareSpecialUse({target}, {this, other}); + NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()) : nullptr); + NDArray::registerSpecialUse({target}, {this, other}); + } + ////////////////////////////////////////////////////////////////////////// void NDArray::applyPairwiseTransform(nd4j::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { applyPairwiseTransform(op, &other, this, extraParams); @@ -3585,6 +3744,45 @@ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const int8_ template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; template void NDArray::applyScalar(nd4j::scalar::BoolOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyScalarArr(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); + + if (target == nullptr || target->dataType() != this->dataType()) + throw std::invalid_argument("NDArray::applyScalarArr int method: target is nullptr or has not bool type!"); + if (dataType() != scalar->dataType()) { + nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar->dataType()); + throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + } + + NDArray::prepareSpecialUse({target}, {this, scalar}); + NativeOpExecutioner::execScalarInt(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo(), scalar->getBuffer(), scalar->getShapeInfo(), scalar->getSpecialBuffer(), scalar->getSpecialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target->dataType()): nullptr); + NDArray::registerSpecialUse({target}, {this, scalar}); + } + +//////////////////////////////////////////////////////////////////////// + template + void NDArray::applyScalar(nd4j::scalar::IntOps op, const T scalar, NDArray *target, ExtraArguments *extraParams) const { + + NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); + applyScalarArr(op, &scalarArr, target, extraParams); + } + + template <> void NDArray::applyScalar(nd4j::scalar::IntOps op, const NDArray* scalar, NDArray *target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const double scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const float16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bfloat16 scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const Nd4jLong scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int16_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const int8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const uint8_t scalar, NDArray *target, ExtraArguments *extraParams) const; + template void NDArray::applyScalar(nd4j::scalar::IntOps op, const bool scalar, NDArray *target, ExtraArguments *extraParams) const; + + //////////////////////////////////////////////////////////////////////// void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const std::vector& dimensions, const ExtraArguments *extraParams) const { if (isS()) diff --git a/libnd4j/blas/NativeOpExecutioner.h b/libnd4j/blas/NativeOpExecutioner.h index 534453a0f..cae7a4e56 100644 --- a/libnd4j/blas/NativeOpExecutioner.h +++ b/libnd4j/blas/NativeOpExecutioner.h @@ -189,6 +189,16 @@ static void execScalarBool(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism = true); +static void execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hSscalarShapeInfo, + void *dScalar, Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism = true); + static void execScalar(nd4j::LaunchContext *lc, int opNum, void *hX, Nd4jLong *hXShapeInfo, @@ -215,6 +225,20 @@ static void execScalarBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + static void execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + /** * * @param opNum @@ -275,6 +299,30 @@ static void execScalarBool(nd4j::LaunchContext *lc, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); + + static void execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *x, Nd4jLong *xShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *y, Nd4jLong *yShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *result, Nd4jLong *resultShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); /** * @@ -308,6 +356,16 @@ static void execScalarBool(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams); + static void execPairwiseIntTransform(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams); + /** * * @param opNum diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index b2ce7846a..22fd9eca4 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -24,6 +24,10 @@ #include #include +#include +#include +#include + #include #include #include @@ -242,6 +246,66 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); } + + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + +void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -297,9 +361,42 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (xType != yType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); + + if (zType != nd4j::DataType::BOOL) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", nd4j::DataType::BOOL, zType); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::pairwise_transforms::PairWiseBoolTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, BOOL_TYPES); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execSPairwiseInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -739,6 +836,64 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, BUILD_DOUBLE_SELECTOR(xType, zType, functions::scalar::ScalarBoolTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hSscalarShapeInfo, + void *dScalar, Nd4jLong *dSscalarShapeInfo, + void *extraParams, bool allowParallelism) { + +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", nd4j::DataType::INT32, zType); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), INTEGER_TYPES); +} + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { +#ifdef _OPENMP + omp_set_nested(1); +#endif + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || xType != zType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); + + if (!nd4j::DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt requires integer data type", zType); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); +} + //////////////////////////////////////////////////////////////////////// /** * diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index 8c4f1d3fa..fcb473820 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -38,19 +38,22 @@ #include #include #include -#include #include #include #include +#include #include +#include +#include #include #include #include #include -#include #include #include +#include #include +#include using namespace nd4j; @@ -152,6 +155,39 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc, throw cuda_exception::build("execPairwiseBoolTransform failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams) { + + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); + + if (yType != xType || zType != xType) + throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform both operands must have same data type", xType, yType); + + dim3 launchDims(256, 1024, 16384); + + BUILD_SINGLE_SELECTOR(xType, functions::pairwise_transforms::PairWiseIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execPairwiseIntTransform failed", res); +} + //////////////////////////////////////////////////////////////////////// void NativeOpExecutioner::execSummaryStatsScalar(nd4j::LaunchContext *lc, int opNum, @@ -252,6 +288,81 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, throw cuda_exception::build("execInverseBroadcastBool failed", res); } + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); + + if (yType != xType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("F3B opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execBroadcastBool failed", res); +} + +void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hY, Nd4jLong *hYShapeInfo, + void *dY, Nd4jLong *dYShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { + auto stream = lc->getCudaStream(); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (!DataTypeUtils::isZ(zType)) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); + + if (yType != xType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("F3BI opNum:[%i]\n", opNum); + + dim3 launchDims(256, 256, 1024); + + BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execInverseBroadcastInt failed", res); +} + //////////////////////////////////////////////////////////////////////// /** * @@ -1114,6 +1225,75 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, throw cuda_exception::build("execScalarBool B failed", res); } +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalar, Nd4jLong *hScalarShapeInfo, + void *dScalar, Nd4jLong *dScalarShapeInfo, + void *extraParams, bool allowParallelism) { + + auto stream = lc->getCudaStream(); + + dim3 launchDims = dim3(256, 512, 8192); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType) ) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalar, extraParams), INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execScalarInt failed", res); +} + +//////////////////////////////////////////////////////////////////////// +void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, + int opNum, + void *hX, Nd4jLong *hXShapeInfo, + void *dX, Nd4jLong *dXShapeInfo, + void *extraParams, + void *hZ, Nd4jLong *hZShapeInfo, + void *dZ, Nd4jLong *dZShapeInfo, + void *hScalars, Nd4jLong *hScalarShapeInfo, + void *dScalars, Nd4jLong *dScalarShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + auto stream = lc->getCudaStream(); + + dim3 launchDims(256, 512, 8192); + + auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); + auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); + auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + + if (xType != yType || zType != xType) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); + + if (!DataTypeUtils::isZ(zType) ) + throw std::runtime_error("NativeOpExecutioner::execScalarInt requires Z operand to have INT type"); + + BUILD_SINGLE_SELECTOR(xType, functions::scalar::ScalarIntTransform, ::executeCudaAlongDimension(launchDims, stream, opNum, dX, dXShapeInfo, dZ, dZShapeInfo, dScalars, extraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), INTEGER_TYPES); + + // TODO: remove after the release + auto res = cudaStreamSynchronize(*stream); + if (res != 0) + throw cuda_exception::build("execScalarInt B failed", res); +} + //////////////////////////////////////////////////////////////////////// void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, int opNum, diff --git a/libnd4j/include/loops/BroadcastPairwiseConverter.h b/libnd4j/include/loops/BroadcastPairwiseConverter.h index 4b11cfc0d..fb5acf19b 100644 --- a/libnd4j/include/loops/BroadcastPairwiseConverter.h +++ b/libnd4j/include/loops/BroadcastPairwiseConverter.h @@ -71,12 +71,25 @@ inline pairwise::BoolOps fromBroadcastToPairwiseBool(broadcast::BoolOps op) { case broadcast::And: return pairwise::And; case broadcast::Or: return pairwise::Or; case broadcast::Xor: return pairwise::Xor; - case broadcast::Not: return pairwise::Not; + case broadcast::Not: return pairwise::Not; default: throw std::runtime_error("fromBroadcastToPairwiseBool: Not convertible operation"); } } + inline pairwise::IntOps fromBroadcastToPairwiseInt(broadcast::IntOps op) { + switch (op) { + case broadcast::IntOps::IntAnd: return pairwise::IntOps::IntAnd; + case broadcast::IntOps::IntOr: return pairwise::IntOps::IntOr; + case broadcast::IntOps::IntXor: return pairwise::IntOps::IntXor; + case broadcast::IntOps::ShiftLeft: return pairwise::IntOps::ShiftLeft; + case broadcast::IntOps::ShiftRight: return pairwise::IntOps::ShiftRight; + case broadcast::IntOps::CyclicShiftLeft: return pairwise::IntOps::CyclicShiftLeft; + case broadcast::IntOps::CyclicShiftRight: return pairwise::IntOps::CyclicShiftRight; + default: + throw std::runtime_error("fromBroadcastToPairwiseInt: Not convertible operation"); + } + } } #endif //DEV_TESTS_BROADCASTPAIRWISECONVERTER_H \ No newline at end of file diff --git a/libnd4j/include/loops/broadcasting_int.h b/libnd4j/include/loops/broadcasting_int.h new file mode 100644 index 000000000..84bc0f949 --- /dev/null +++ b/libnd4j/include/loops/broadcasting_int.h @@ -0,0 +1,164 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * broadcasting.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef BROADCASTING_INT_H_ +#define BROADCASTING_INT_H_ +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#include +#endif +#ifdef __JNI__ +#include +#endif + +#include + +#include "legacy_ops.h" + +namespace functions { + namespace broadcast { + +/** + * Broadcast operation + * for broadcasting a smaller tensor + * along long a bigger one. + */ + template + class BroadcastInt { + public: + +#ifdef __CUDACC__ + + template + static __device__ void transformCuda( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __device__ void transformInverseCuda( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + +#endif + + static void exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + static void execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + /** + * CPU execution + * @param x the input + * @param xShapeInfo the x shape information + * @param y the y data + * @param yShapeInfo the y shape information + * @param result the result + * @param resultShapeInfo the result shape information + * @param dimension the dimension to broadcast along long + * @param dimensionLength the length of the dimension buffer + */ + template + static void exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + + template + static void execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ); + }; + } +} + +#endif /* BROADCASTING_H_ */ diff --git a/libnd4j/include/loops/cpu/broadcasting_int.cpp b/libnd4j/include/loops/cpu/broadcasting_int.cpp new file mode 100644 index 000000000..c092da50b --- /dev/null +++ b/libnd4j/include/loops/cpu/broadcasting_int.cpp @@ -0,0 +1,464 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include + +using namespace simdOps; + +namespace functions { + namespace broadcast { + + template + void BroadcastInt::exec(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + dimension, + dimensionLength, + xTadShapeInfo, + xTadOffset, + zTadShapeInfo, + zTadOffset), BROADCAST_INT_OPS); + } + + template + void BroadcastInt::execInverse(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + DISPATCH_BY_OPNUM_T(execInverse, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + dimension, + dimensionLength, + xTadShapeInfo, + xTadOffset, + zTadShapeInfo, + zTadOffset), BROADCAST_INT_OPS); + } + + template + template + void BroadcastInt::exec(void *vx, + Nd4jLong *xShapeInfo, + void *vy, + Nd4jLong *yShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. + //permuted version of the x shape info for setting up the tad problem + auto xTadShapeShapeInfo = xTadShapeInfo; + auto tadOffsets = xTadOffset; + + if (xTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(xShapeInfo, dimension, dimensionLength); + + xTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + //int *resultStride = shape::stride(xTadShapeShapeInfo); + unsigned int tadLength = shape::length(xTadShapeShapeInfo);//shape::length(xTadShapeShapeInfo); + unsigned int tads = shape::length(xShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenY = shape::length(yShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = nd4j::math::nd4j_max(1, tadsPerThread); + threads = nd4j::math::nd4j_min(threads, omp_get_max_threads()); + + auto xEws = shape::elementWiseStride(xTadShapeShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], y[f]); + } + } + else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oX = x + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + // TODO: cover this codebranch with tests + // all this stuff already happens within thread + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + oZ[offset] = OpType::op(oX[offset], y[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(oX[offset], y[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + oZ[offset] = OpType::op(oX[offset], y[yOffset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + oZ[offset] = OpType::op(oX[xOffset], y[offset]); + } + } + } + else { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oX = x + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastX); + auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, lenY, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + } + } + } + } + + + template + template + void BroadcastInt::execInverse(void *vx, + Nd4jLong *xShapeInfo, + void *vy, + Nd4jLong *yShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *yTadShapeInfo, + Nd4jLong *yTadOffset, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffset) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. + //permuted version of the x shape info for setting up the tad problem + auto yTadShapeShapeInfo = yTadShapeInfo; + auto tadOffsets = yTadOffset; + + if (yTadShapeInfo == nullptr || tadOffsets == nullptr) { + auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(yShapeInfo, dimension, dimensionLength); + + yTadShapeShapeInfo = tadPack.primaryShapeInfo(); + tadOffsets = tadPack.primaryOffsets(); + } + + //int *resultStride = shape::stride(yTadShapeShapeInfo); + unsigned int tadLength = shape::length(yTadShapeShapeInfo); + unsigned int tads = shape::length(yShapeInfo) / tadLength; + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = yTadShapeShapeInfo; + zTadOffset = tadOffsets; + } + + auto lenZ = shape::length(zTadShapeInfo); + auto lenX = shape::length(xShapeInfo); + + int tadsPerThread = tads / TAD_THRESHOLD; + int threads = nd4j::math::nd4j_max(1, tadsPerThread); + threads = nd4j::math::nd4j_min(threads, omp_get_max_threads()); + + auto yEws = shape::elementWiseStride(yTadShapeShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zTadShapeInfo); + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(yTadShapeShapeInfo, xShapeInfo, zTadShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(x[f], oY[f]); + } + } + else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + PRAGMA_OMP_SIMD + for (uint f = 0; f < tadLength; f++) + oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oY = y + tadOffsets[i]; + auto oZ = z + zTadOffset[i]; + + // TODO: cover this codebranch with tests + // all this stuff already happens within thread + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + oZ[offset] = OpType::op(x[offset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(x[offset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + oZ[offset] = OpType::op(x[xOffset], oY[offset]); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, zTadShapeInfo)) { + + uint tadShapeShapeInfoCast[MAX_RANK]; + uint xShapeInfoCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + oZ[offset] = OpType::op(x[offset], oY[yOffset]); + } + } + } + else { + + uint xShapeInfoCast[MAX_RANK]; + uint tadShapeShapeInfoCast[MAX_RANK]; + uint tadShapeInfoZCast[MAX_RANK]; + bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yTadShapeShapeInfo, tadShapeShapeInfoCast); + bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zTadShapeInfo, tadShapeInfoZCast); + + PRAGMA_OMP_PARALLEL_FOR_THREADS(threads) + for (int i = 0; i < tads; i++) { + + auto oZ = z + zTadOffset[i]; + auto oY = y + tadOffsets[i]; + + PRAGMA_OMP_SIMD + for (int f = 0; f < tadLength; f++) { + auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, lenX, canCastX); + auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, tadLength, canCastY); + auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, lenZ, canCastZ); + oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + } + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/pairwise_int.cpp b/libnd4j/include/loops/cpu/pairwise_int.cpp new file mode 100644 index 000000000..b356adcc2 --- /dev/null +++ b/libnd4j/include/loops/cpu/pairwise_int.cpp @@ -0,0 +1,309 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include + +using namespace simdOps; + +namespace functions { + namespace pairwise_transforms { + + template + void PairWiseIntTransform::exec( + const int opNum, + void *x, + Nd4jLong xEws, + void *y, + Nd4jLong yEws, + void *z, + Nd4jLong zEws, + void *extraParams, + Nd4jLong n) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xEws, + y, + yEws, + z, + zEws, + extraParams, + n), PAIRWISE_INT_OPS); + }; + + + + template + template + void PairWiseIntTransform::exec(void *vx, + Nd4jLong xEws, + void *vy, + Nd4jLong yEws, + void *vz, + Nd4jLong zEws, + void *vextraParams, + const Nd4jLong n) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + nd4j::OmpLaunchHelper info(n); + + if (xEws == 1 && yEws == 1 && zEws == 1) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + Nd4jLong threadOffset = info.getThreadOffset(threadNum); + auto xi = x + threadOffset; + auto yi = y + threadOffset; + auto zi = z + threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) + zi[i] = OpType::op(xi[i], yi[i], extraParams); + } + } + else { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + Nd4jLong threadOffset = info.getThreadOffset(threadNum); + auto xi = x + xEws*threadOffset; + auto yi = y + yEws*threadOffset; + auto zi = z + zEws*threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) + zi[i*zEws] = OpType::op(xi[i*xEws], yi[i*yEws], extraParams); + } + } + } + + template + void PairWiseIntTransform::exec( + const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *extraParams) { + DISPATCH_BY_OPNUM_T(exec, PARAMS(x, + xShapeInfo, + y, + yShapeInfo, + z, + zShapeInfo, + extraParams), + PAIRWISE_INT_OPS); + }; + + + template + template + void PairWiseIntTransform::exec(void *vx, Nd4jLong* xShapeInfo, + void *vy, Nd4jLong* yShapeInfo, + void *vz, Nd4jLong* zShapeInfo, + void *vextraParams) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + auto n = shape::length(xShapeInfo); + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + nd4j::OmpLaunchHelper info(n); + + if (shape::isScalar(yShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for(Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + z[offset] = OpType::op(x[offset], y[0], extraParams); + } + } + } + else { + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for(Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[0], extraParams); + } + } + } + return; + } + + const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo); + const bool sameShapesXY = shape::shapeEquals(xShapeInfo, yShapeInfo); + + if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && sameShapesXY) { + exec(x, xEws, y, yEws, z, zEws, extraParams, n); + } + else if ((kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) && !sameShapesXY) { //not same shape + exec(x, xEws, y, yEws, z, zEws, extraParams, shape::length(yShapeInfo)); + } + else { + + if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + z[offset] = OpType::op(x[offset], y[offset], extraParams); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[offset], y[offset], extraParams); + } + } + } + else if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + z[offset] = OpType::op(x[offset], y[yOffset], extraParams); + } + } + } + else if(shape::haveSameShapeAndStrides(yShapeInfo, zShapeInfo)) { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto offset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + z[offset] = OpType::op(x[xOffset], y[offset], extraParams); + } + } + } + else { + + uint xShapeInfoCast[MAX_RANK]; + uint yShapeInfoCast[MAX_RANK]; + uint zShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + const bool canCastY = nd4j::DataTypeUtils::castShapeInfo(yShapeInfo, yShapeInfoCast); + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (Nd4jLong i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, n, canCastX); + auto yOffset = shape::indexOffset(i + threadOffset, yShapeInfo, yShapeInfoCast, n, canCastY); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, n, canCastZ); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); + } +} diff --git a/libnd4j/include/loops/cpu/scalar_int.cpp b/libnd4j/include/loops/cpu/scalar_int.cpp new file mode 100644 index 000000000..9920cc836 --- /dev/null +++ b/libnd4j/include/loops/cpu/scalar_int.cpp @@ -0,0 +1,255 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "../scalar_int.h" +#include +#include +#include + +#include "../legacy_ops.h" + +using namespace simdOps; + +namespace functions { + namespace scalar { + + + template + template + void ScalarIntTransform::transform(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, + void *vz, Nd4jLong *zShapeInfo, + void *vscalars, + int *dimension, int dimensionLength, + Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, + Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets) { + + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalars = reinterpret_cast(vscalars); + auto extraParams = reinterpret_cast(vextraParams); + + if (zTadShapeInfo == nullptr) { + zTadShapeInfo = xTadShapeInfo; + zTadOffsets = xTadOffsets; + } + + // tad preparation + const int xTadEws = shape::elementWiseStride(xTadShapeInfo); + const int zTadEws = shape::elementWiseStride(zTadShapeInfo); + const int tadLength = shape::tadLength(xShapeInfo, dimension, dimensionLength); + const int numTads = shape::length(xShapeInfo) / tadLength; + + nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xTadShapeInfo, zTadShapeInfo); + + if (kindOfLoop != nd4j::LoopKind::EWS1 && kindOfLoop != nd4j::LoopKind::EWSNONZERO) { + printf("ScalarIntTransform::transform: super-bad loop visited. Shouldn't ever happen\n"); + return; + } + + int num_threads = nd4j::math::nd4j_min(numTads, omp_get_max_threads()); + + if (kindOfLoop == nd4j::LoopKind::EWS1) { + PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) + for (unsigned int r = 0; r < numTads; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f] = OpType::op(oX[f], scalars[r], extraParams); + } + } + else { // kindOfLoop != nd4j::LoopKind::EWSNONZERO + PRAGMA_OMP_PARALLEL_FOR_THREADS(num_threads) + for (unsigned int r = 0; r < numTads; r++) { + auto oZ = z + zTadOffsets[r]; + auto oX = x + xTadOffsets[r]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < tadLength; f++) + oZ[f * zTadEws] = OpType::op(oX[f * xTadEws], scalars[r], extraParams); + } + } + } + + template + void ScalarIntTransform::transform(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *extraParams, + void *z, + Nd4jLong *zShapeInfo, + void *scalars, + int *dimension, + int dimensionLength, + Nd4jLong *xTadShapeInfo, + Nd4jLong *xTadOffsets, + Nd4jLong *zTadShapeInfo, + Nd4jLong *zTadOffsets) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets), SCALAR_INT_OPS); + } + + + template + void ScalarIntTransform::transform(const int opNum, + void *x, + Nd4jLong xEws, + void *z, + Nd4jLong zEws, + void *scalar, + void *extraParams, + const Nd4jLong n) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xEws, z, zEws, scalar, extraParams, n), SCALAR_INT_OPS); + } + + template + void ScalarIntTransform::transform(const int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + void *scalar, + void *extraParams) { + DISPATCH_BY_OPNUM_T(transform, PARAMS(x, xShapeInfo, z, zShapeInfo, scalar, extraParams), SCALAR_INT_OPS); + } + + template + template + void ScalarIntTransform::transform(void *vx, + Nd4jLong *xShapeInfo, + void *vz, + Nd4jLong *zShapeInfo, + void *vscalar, + void *vextraParams) { + + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + auto len = shape::length(xShapeInfo); + + // nd4j_logger("Launching scalar: xOrder: %i; zOrder: %i; xEWS: %i\n", xOrder, zOrder, xEws); + + nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXZ(xShapeInfo, zShapeInfo); + + if (kindOfLoop == nd4j::LoopKind::EWS1 || kindOfLoop == nd4j::LoopKind::EWSNONZERO) { + transform(x, xEws, z, zEws, vscalar, extraParams, len); + return; + } + + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); + + nd4j::OmpLaunchHelper info(len); + + if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) { + auto offset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX); + z[offset] = OpType::op(x[offset], scalar, extraParams); + } + } + } + else { + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) { + auto xOffset = shape::indexOffset(i + threadOffset, xShapeInfo, xShapeInfoCast, len, canCastX); + auto zOffset = shape::indexOffset(i + threadOffset, zShapeInfo, zShapeInfoCast, len, canCastZ); + z[zOffset] = OpType::op(x[xOffset], scalar, extraParams); + } + } + } + } + + + template + template + void ScalarIntTransform::transform(void *vx, + Nd4jLong xEws, + void *vz, + Nd4jLong zEws, + void *vscalar, + void *vextraParams, + const Nd4jLong len) { + + auto x = reinterpret_cast(vx); + auto z = reinterpret_cast(vz); + auto scalar = reinterpret_cast(vscalar)[0]; + auto extraParams = reinterpret_cast(vextraParams); + + nd4j::OmpLaunchHelper info(len); + + if (xEws == 1 && zEws == 1) { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto xi = x + threadOffset; + auto zi = z + threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) + zi[i] = OpType::op(xi[i], scalar, extraParams); + } + } + else { + + PRAGMA_OMP_PARALLEL_THREADS(info._numThreads) + { + auto threadNum = omp_get_thread_num(); + auto threadOffset = info.getThreadOffset(threadNum); + auto xi = x + xEws * threadOffset; + auto zi = z + zEws * threadOffset; + auto ulen = static_cast(info.getItersPerThread(threadNum)); + + PRAGMA_OMP_SIMD + for (unsigned int i = 0; i < ulen; i++) + zi[i * zEws] = OpType::op(xi[i * xEws], scalar, extraParams); + } + } + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + +} +} diff --git a/libnd4j/include/loops/cuda/broadcasting_int.cu b/libnd4j/include/loops/cuda/broadcasting_int.cu new file mode 100644 index 000000000..38193f35d --- /dev/null +++ b/libnd4j/include/loops/cuda/broadcasting_int.cu @@ -0,0 +1,291 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace simdOps; + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void broadcastIntSimple( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::broadcast::BroadcastInt::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +} + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void broadcastBoolInverseSimple( + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *z, + Nd4jLong *zShapeInfo, + int *dimension, + int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::broadcast::BroadcastInt::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); +} + +namespace functions { + namespace broadcast { +////////////////////////////////////////////////////////////////////////// + template + template + __host__ void BroadcastInt::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastIntSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + } + +////////////////////////////////////////////////////////////////////////// + template + __host__ void BroadcastInt::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) + } + +////////////////////////////////////////////////////////////////////////// + template + template + __host__ void BroadcastInt::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + } + +////////////////////////////////////////////////////////////////////////// + template + __host__ void BroadcastInt::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_INT_OPS)) + } + +////////////////////////////////////////////////////////////////////////// + template + template + __device__ void BroadcastInt::transformInverseCuda( + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. + //permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong xEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(yShapeInfo) / tadLength; + xEWS = shape::elementWiseStride(xShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + auto rZ = z + tadOffsetsZ[r]; + auto rY = y + tadOffsets[r]; + + if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { + + for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + } + else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, tadLength); + auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + } + } + } + } + +////////////////////////////////////////////////////////////////////////// + template + template + __device__ void BroadcastInt::transformCuda( + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + int *dimension, int dimensionLength, + Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + if (tadOnlyShapeInfoZ == nullptr) { + tadOnlyShapeInfoZ = tadOnlyShapeInfo; + tadOffsetsZ = tadOffsets; + } + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + + //decompose in to several sub tads after + //moving all dimensions (in sorted order) + //to the back. + //permuted version of the x shape info for setting up the tad problem + __shared__ Nd4jLong tadLength; + __shared__ Nd4jLong tadEWS; + __shared__ int numTads; + __shared__ Nd4jLong yEWS; + __shared__ Nd4jLong zEWS; + + if (threadIdx.x == 0) { + tadLength = shape::length(tadOnlyShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + tadEWS = shape::elementWiseStride(tadOnlyShapeInfo); + numTads = shape::length(xShapeInfo) / tadLength; + yEWS = shape::elementWiseStride(yShapeInfo); + zEWS = shape::elementWiseStride(tadOnlyShapeInfoZ); + } + __syncthreads(); + + __shared__ X *rZ; + __shared__ X *rX; + + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + + if (threadIdx.x == 0) { + rZ = z + tadOffsetsZ[r]; + rX = x + tadOffsets[r]; + } + __syncthreads(); + + + if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { + + for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + } + else { + // it is expected that x and z tads and y array all have the same length + for (Nd4jLong i = threadIdx.x; i < tadLength; i+= blockDim.x) { + auto xOffset = shape::getIndexOffset(i, tadOnlyShapeInfo, tadLength); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, tadLength); + auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ, tadLength); + + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + } + } + } + } + + + template + void BroadcastInt::exec(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + void BroadcastInt::execInverse(int opNum, + void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastInt::exec(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + template + template + void BroadcastInt::execInverse(void *x, + Nd4jLong *xShapeInfo, + void *y, + Nd4jLong *yShapeInfo, + void *result, + Nd4jLong *resultShapeInfo, + int *dimension, + int dimensionLength, + Nd4jLong *tadShapeInfo, + Nd4jLong *tadOffset, + Nd4jLong *tadShapeInfoZ, + Nd4jLong *tadOffsetZ) { + + } + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT BroadcastInt, , INTEGER_TYPES); + } +} \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/pairwise_int.cu b/libnd4j/include/loops/cuda/pairwise_int.cu new file mode 100644 index 000000000..5cc12846c --- /dev/null +++ b/libnd4j/include/loops/cuda/pairwise_int.cu @@ -0,0 +1,173 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018 + +#ifndef PAIRWISE_INT_CU +#define PAIRWISE_INT_CU + + +#include "../pairwise_int.h" + + +using namespace simdOps; + +//////////////////////////////////////////////////////////////////////////////// +template +__global__ static void pairwiseSimpleShaped(void* vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void *vextraParams) { + + auto x = reinterpret_cast(vx); + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int xEws; + __shared__ int yEws; + __shared__ int zEws; + __shared__ char xOrder; + __shared__ char yOrder; + __shared__ char zOrder; + __shared__ Nd4jLong len; + + if (threadIdx.x == 0) { + xEws = shape::elementWiseStride(xShapeInfo); + yEws = shape::elementWiseStride(yShapeInfo); + zEws = shape::elementWiseStride(zShapeInfo); + xOrder = shape::order(xShapeInfo); + yOrder = shape::order(yShapeInfo); + zOrder = shape::order(zShapeInfo); + len = shape::length(xShapeInfo); + } + __syncthreads(); + + + if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == yOrder && xOrder == zOrder) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + z[i * zEws] = OpType::op(x[i * xEws], y[i * yEws], extraParams); + } + } + else if (vx == vz) { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, len); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, len); + + z[xOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } + else { + for (Nd4jLong i = tid; i < len; i += gridDim.x * blockDim.x) { + auto xOffset = shape::getIndexOffset(i, xShapeInfo, len); + auto yOffset = shape::getIndexOffset(i, yShapeInfo, len); + auto zOffset = shape::getIndexOffset(i, zShapeInfo, len); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], extraParams); + } + } +} + + +namespace functions { +namespace pairwise_transforms { + +//////////////////////////////////////////////////////////////////////////////// +template +template +void _CUDA_H PairWiseIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, + void *vx, Nd4jLong *xShapeInfo, + void *vy, Nd4jLong *yShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void *vextraParams){ + + pairwiseSimpleShaped<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams); +} + + +//////////////////////////////////////////////////////////////////////////////// +template +void PairWiseIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams) { + auto xType = nd4j::DataTypeUtils::fromT(); + + DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vextraParams), PAIRWISE_INT_OPS); +} + + + template + void PairWiseIntTransform::exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams) { + + } + + template + void PairWiseIntTransform::exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n) { + + } + + + template + template + void PairWiseIntTransform::exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams) { + + } + + template + template + void PairWiseIntTransform::exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n) { + + } + + + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT PairWiseIntTransform, , INTEGER_TYPES); +} +} + +#endif // PAIRWISE_INT_CU \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/scalar_int.cu b/libnd4j/include/loops/cuda/scalar_int.cu new file mode 100644 index 000000000..48f141525 --- /dev/null +++ b/libnd4j/include/loops/cuda/scalar_int.cu @@ -0,0 +1,269 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 08.11.2018 +// @author raver119@gmail.com +// + +#include "../scalar_int.h" +#include +#include + +#include "../legacy_ops.h" + +using namespace simdOps; + +//////////////////////////////////////////////////////////////////////// +template +__global__ void scalarAlongDimension(void *x, Nd4jLong *xShapeInfo, + void *extraParams, + void *z, Nd4jLong *zShapeInfo, + void *scalars, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + functions::scalar::ScalarIntTransform::template transformCuda(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +} + + +//////////////////////////////////////////////////////////////////////// +template +__global__ void scalarSimpleShaped(void* x, void *y, Nd4jLong *xShapeInfo, void *params, void *z, Nd4jLong *zShapeInfo, int *allocationBuffer) { + + functions::scalar::ScalarIntTransform::template transformCuda(y, x, xShapeInfo, params, z, zShapeInfo, allocationBuffer); +} + + + + + +// *********************************************************************// +// *********************************************************************// +namespace functions { +namespace scalar { + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(void* vscalar, + void *vy, Nd4jLong *yShapeInfo, + void *vparams, + void *vz, Nd4jLong *zShapeInfo, + int *allocationBuffer) { + auto scalar = reinterpret_cast(vscalar)[0]; + auto y = reinterpret_cast(vy); + auto params = reinterpret_cast(vparams); + auto z = reinterpret_cast(vz); + + auto yRank = shape::rank(yShapeInfo); + auto yEWS = shape::elementWiseStride(yShapeInfo); + auto yShape = shape::shapeOf(yShapeInfo); + auto yStride = shape::stride(yShapeInfo); + + auto zRank = shape::rank(zShapeInfo); + auto zEWS = shape::elementWiseStride(zShapeInfo); + auto zShape = shape::shapeOf(zShapeInfo); + auto zStride = shape::stride(zShapeInfo); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int len; + if(threadIdx.x == 0) + len = shape::length(yShapeInfo); + __syncthreads(); + + if(yEWS >= 1 && zEWS >= 1 && shape::order(yShapeInfo) == shape::order(zShapeInfo)) { + transformCuda(len, vscalar, vy, yEWS, vparams, vz, zEWS, allocationBuffer); + } + else { + for (Nd4jLong i = tid; i < len; i+= totalThreads) + z[shape::getIndexOffset(i, zShapeInfo, len)] = OpType::op(y[shape::getIndexOffset(i, yShapeInfo, len)], scalar, params); + } +} + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(Nd4jLong len, + void* vx, + void *vy, Nd4jLong yEWS, + void *vparams, + void *vz, Nd4jLong zEWS, + int *allocationBuffer) { + + auto x = reinterpret_cast(vx)[0]; + auto y = reinterpret_cast(vy); + auto z = reinterpret_cast(vz); + auto params = reinterpret_cast(vparams); + + int totalThreads = gridDim.x * blockDim.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + Nd4jLong i = tid; + if(yEWS == 1 && zEWS == 1) { + for (; i < len; i += totalThreads) + z[i] = OpType::op(y[i], x, params); + } + else { + for (; i < len; i += totalThreads) + z[i * zEWS] = OpType::op(y[i * yEWS], x, params); + } +} + + +//////////////////////////////////////////////////////////////////////// +template +template +__device__ void ScalarIntTransform::transformCuda(void *vx, Nd4jLong *xShapeInfo, + void *vextraParams, + void *vz, Nd4jLong *zShapeInfo, + void *vscalars, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + auto x = reinterpret_cast(vx); + auto scalars = reinterpret_cast(vscalars); + auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); + + if (tadShapeInfoZ == nullptr) { + tadShapeInfoZ = tadShapeInfo; + tadOffsetsZ = tadOffsets; + } + + // tad preparation + auto tadEws = shape::elementWiseStride(tadShapeInfo); + auto zEws = shape::elementWiseStride(tadShapeInfoZ); + auto tadLength = shape::length(tadShapeInfo);//shape::tadLength(xShapeInfo, dimension, dimensionLength); + auto numTads =shape::length(xShapeInfo) / tadLength; + + if (tadEws > 0 && zEws > 0 && shape::order(tadShapeInfo) == shape::order(zShapeInfo)) { + + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X *oZ = z + tadOffsetsZ[r]; + X *oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[f * zEws] = OpType::op(oX[f * tadEws], s, extraParams); + } + } else { + // main loop, rolling over tads + for (int r = blockIdx.x; r < numTads; r += gridDim.x) { + X *oZ = z + tadOffsetsZ[r]; + X *oX = x + tadOffsets[r]; + + auto s = scalars[r]; + + for (int f = threadIdx.x; f < tadLength; f += blockDim.x) + oZ[shape::getIndexOffset(f, tadShapeInfoZ, tadLength)] = OpType::op(oX[shape::getIndexOffset(f, tadShapeInfo, tadLength)], s, extraParams); + } + } +} + + +//////////////////////////////////////////////////////////////////////// +template +template +_CUDA_H void ScalarIntTransform::intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, + void *x, Nd4jLong *xShapeInfo, + void *z, Nd4jLong *zShapeInfo, + void *scalars, + void *extraParams, + int *dimension, int dimensionLength, + Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, + Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + scalarAlongDimension<<>>(x, xShapeInfo, extraParams, z, zShapeInfo, scalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ); +} + +//////////////////////////////////////////////////////////////////////// +template +template +void _CUDA_H ScalarIntTransform::intermediateShaped(dim3& launchDims, cudaStream_t *stream, + void *vx, Nd4jLong *xShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void* vscalar, + void *vextraParams, int *allocPointer){ + + scalarSimpleShaped<<>>(vx, vscalar, xShapeInfo, vextraParams, vz, zShapeInfo, allocPointer); +} + +//////////////////////////////////////////////////////////////////////// +template +void ScalarIntTransform::executeCudaShaped(dim3& launchDims, cudaStream_t *stream, + int opNum, + void *vx, Nd4jLong *xShapeInfo, + void *vz, Nd4jLong *zShapeInfo, + void* vscalar, + void *vextraParams) { + + if (nd4j::Environment::getInstance()->isDebugAndVerbose()) + printf("H14 opNum:[%i]\n", opNum); + + DISPATCH_BY_OPNUM_T(intermediateShaped, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalar, vextraParams, nullptr), SCALAR_INT_OPS); +} + +//////////////////////////////////////////////////////////////////////// +template +void ScalarIntTransform::executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vscalars, void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_T(intermediateAlongDimension, PARAMS(launchDims, stream, vx, xShapeInfo, vz, zShapeInfo, vscalars, vextraParams, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), SCALAR_INT_OPS); +} + + BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT ScalarIntTransform, , INTEGER_TYPES); + + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarIntTransform::transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { + + } + + template + void ScalarIntTransform::transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + template + void ScalarIntTransform::transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams) { + + } + + + template + template + void ScalarIntTransform::transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n) { + + } +} +} + diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b3096ac0e..b0d891287 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -29,6 +29,16 @@ (4, aggregateOps::CBOW) ,\ (5, aggregateOps::GEMM) +#define BROADCAST_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) + + #define BROADCAST_BOOL_OPS \ (0, EqualTo),\ (1, GreaterThan),\ @@ -171,6 +181,14 @@ (0, SummaryStatsVariance), \ (1, SummaryStatsStandardDeviation) +#define SCALAR_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) #define SCALAR_BOOL_OPS \ (0, EqualTo),\ @@ -300,6 +318,15 @@ (13, ExponentialDistribution),\ (14, ExponentialDistributionInv) +#define PAIRWISE_INT_OPS \ + (0, ShiftLeft), \ + (1, ShiftRight), \ + (2, CyclicShiftLeft), \ + (3, CyclicShiftRight), \ + (4, IntAnd), \ + (5, IntOr), \ + (6, IntXor) + #define PAIRWISE_BOOL_OPS \ (0, EqualTo),\ (1, GreaterThan),\ diff --git a/libnd4j/include/loops/pairwise_int.h b/libnd4j/include/loops/pairwise_int.h new file mode 100644 index 000000000..14d273285 --- /dev/null +++ b/libnd4j/include/loops/pairwise_int.h @@ -0,0 +1,119 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * pairwise_transform.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef PAIRWISE_INT_H_ +#define PAIRWISE_INT_H_ +#ifdef _OPENMP +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include +#include +#endif + +#ifndef _OPENMP +#define omp_get_thread_num() 0 +#define omp_get_max_threads() 1 +#endif + + +#include "legacy_ops.h" + +using namespace simdOps; + +namespace functions { + namespace pairwise_transforms { + +/** + * Transforms involving 2 arrays + */ + template + class PairWiseIntTransform { + public: + +#ifdef __CUDACC__ + + template + static __host__ void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, void *vextraParams); + + static __host__ void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams); + + +#endif + public: + + static void exec( + const int opNum, + void *dx, + Nd4jLong *xShapeBuffer, + void *y, + Nd4jLong *yShapeBuffer, + void *result, + Nd4jLong *resultShapeBuffer, + void *extraParams); + + static void exec( + const int opNum, + void *dx, + Nd4jLong xStride, + void *y, + Nd4jLong yStride, + void *result, + Nd4jLong resultStride, + void *extraParams, + Nd4jLong n); + + + template + static void exec( + void *vx, + Nd4jLong* xShapeBuffer, + void *vy, + Nd4jLong* yShapeBuffer, + void *vresult, + Nd4jLong* resultShapeBuffer, + void *vextraParams); + + template + static void exec(void *vx, + Nd4jLong xStride, + void *vy, + Nd4jLong yStride, + void *vresult, + Nd4jLong resultStride, + void *vextraParams, + const Nd4jLong n); + }; + } +} + +#endif /* PAIRWISE_TRANSFORM_H_ */ diff --git a/libnd4j/include/loops/scalar_int.h b/libnd4j/include/loops/scalar_int.h new file mode 100644 index 000000000..f873d5419 --- /dev/null +++ b/libnd4j/include/loops/scalar_int.h @@ -0,0 +1,142 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +/* + * scalar.h + * + * Created on: Dec 28, 2015 + * Author: agibsonccc + */ + +#ifndef SCALAR_INT_H_ +#define SCALAR_INT_H_ +#include + +#ifdef __JNI__ +#include +#endif +#include +#include +#include +#include "helpers/logger.h" +#include +#include + +#ifdef __CUDACC__ +#include +#include +#include +#endif + +#include "legacy_ops.h" + +namespace functions { + namespace scalar { +/** + * Apply a scalar + * operation to an array + */ + template + class ScalarIntTransform { + + public: + +#ifdef __CUDACC__ + + template + __device__ + static void transformCuda(void* scalar, void *vy, Nd4jLong *shapeInfo, void *vparams, void *vresult, Nd4jLong *resultShapeInfo, int *allocationBuffer); + + template + __device__ + static void transformCuda(Nd4jLong n, void* vx, void *vy, Nd4jLong yEWS, void *vparams, void *vz, Nd4jLong zEWS, int *allocationBuffer); + + template + __device__ + static void transformCuda(void *vx, Nd4jLong *xShapeInfo, void *vextraParams, void *vz, Nd4jLong *zShapeInfo, void *vscalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + __host__ + static void intermediateAlongDimension(dim3& launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + template + __host__ + static void intermediateShaped(dim3& launchDims, cudaStream_t *stream, void *vx, Nd4jLong *xShapeInfo, void *vz, Nd4jLong *zShapeInfo, void* vscalar, void *vextraParams, int *allocPointer); + + __host__ + static void executeCudaShaped(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void* scalar, void *extraParams); + + __host__ + static void executeCudaAlongDimension(dim3& launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *z, Nd4jLong *zShapeInfo, void *scalars, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + +/* +#include "cuda/scalar_temp.cu" +*/ +#endif + template + static void transform(void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void transform(int opNum, void *x, Nd4jLong *xShapeInfo, void *extraParams, void *z, Nd4jLong *zShapeInfo, void *scalars, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); + + static void transform(const int opNum, void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams); + + static void transform(const int opNum, void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n); + + + + + /* + * ScalarOp along dimension + */ + + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(void *x, Nd4jLong *xShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *scalar, void *extraParams); + + + /** + * CPU implementation of scalar operation + * @param x the input + * @param xStride the stride for the input + * @param result the result buffer + * @param resultStride the stride for the result + * @param scalar the scalar to apply + * @param extraParams the extra parameters where + * neccssary + * @param n the number of elements to loop over + */ + + template + static void transform(void *x, Nd4jLong xStride, void *result, Nd4jLong resultStride, void *scalar, void *extraParams, const Nd4jLong n); + }; + } +} + + +#endif /* SCALAR_H_ */ diff --git a/libnd4j/include/op_enums.h b/libnd4j/include/op_enums.h index bf9d4c72f..8a100153f 100644 --- a/libnd4j/include/op_enums.h +++ b/libnd4j/include/op_enums.h @@ -63,6 +63,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(PAIRWISE_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(PAIRWISE_INT_OPS) + }; } namespace scalar { @@ -73,6 +77,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(SCALAR_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(SCALAR_INT_OPS) + }; } namespace reduce { @@ -113,6 +121,10 @@ namespace nd4j { enum BoolOps { BUILD_ENUMERATION(BROADCAST_BOOL_OPS) }; + + enum IntOps { + BUILD_ENUMERATION(BROADCAST_INT_OPS) + }; } namespace variance { diff --git a/libnd4j/include/ops/BroadcastIntOpsTuple.h b/libnd4j/include/ops/BroadcastIntOpsTuple.h new file mode 100644 index 000000000..df40907a9 --- /dev/null +++ b/libnd4j/include/ops/BroadcastIntOpsTuple.h @@ -0,0 +1,49 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef DEV_TESTS_BROADCASTINTOPSTUPLE_H +#define DEV_TESTS_BROADCASTINTOPSTUPLE_H + +#include + +namespace nd4j { + class BroadcastIntOpsTuple { + private: + + public: + nd4j::scalar::IntOps s; + nd4j::pairwise::IntOps p; + nd4j::broadcast::IntOps b; + + BroadcastIntOpsTuple() = default; + ~BroadcastIntOpsTuple() = default; + + BroadcastIntOpsTuple(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) { + s = scalar; + p = pairwise; + b = broadcast; + } + + static BroadcastIntOpsTuple custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast); + }; +} + + +#endif //DEV_TESTS_BROADCASTOPSTUPLE_H diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp index 2aac5c6f9..89d380d02 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(cyclic_rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift); - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type") + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftRight, pairwise::CyclicShiftRight, broadcast::CyclicShiftRight), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp index 0bdb9503d..f18314910 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(cyclic_shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift); - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type") + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::CyclicShiftLeft, pairwise::CyclicShiftLeft, broadcast::CyclicShiftLeft), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp index 4068351a2..36b0defd0 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(rshift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type") - - helpers::rshift_bits(block.launchContext(), *input, *output, shift); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftRight, pairwise::ShiftRight, broadcast::ShiftRight), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp index f79da1024..ab4ed9880 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -27,22 +27,14 @@ namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); + BROADCASTABLE_OP_IMPL(shift_bits, 0, 0) { + auto x = INPUT_VARIABLE(0); + auto y = INPUT_VARIABLE(1); + auto z = OUTPUT_VARIABLE(0); - REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing"); + BROADCAST_CHECK_EMPTY(x,y,z); - uint32_t shift = 0; - if (block.width() > 1) { - shift = INPUT_VARIABLE(1)->e(0); - } else if (block.numI() > 0) { - shift = INT_ARG(0); - }; - - REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type") - - helpers::shift_bits(block.launchContext(), *input, *output, shift); + x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::ShiftLeft, pairwise::ShiftLeft, broadcast::ShiftLeft), y, z, false); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index da5ff4047..a6362a73f 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -45,7 +45,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_shift_bits) - DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(shift_bits, 0, 0); #endif /** @@ -56,7 +56,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_rshift_bits) - DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(rshift_bits, 0, 0); #endif /** @@ -67,7 +67,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_cyclic_shift_bits) - DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(cyclic_shift_bits, 0, 0); #endif /** @@ -78,7 +78,7 @@ namespace nd4j { * @tparam T */ #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2); + DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0); #endif /** diff --git a/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp new file mode 100644 index 000000000..607572b59 --- /dev/null +++ b/libnd4j/include/ops/impl/BroadcastIntOpsTuple.cpp @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// +#include + +namespace nd4j { + BroadcastIntOpsTuple BroadcastIntOpsTuple::custom(nd4j::scalar::IntOps scalar, nd4j::pairwise::IntOps pairwise, nd4j::broadcast::IntOps broadcast) { + BroadcastIntOpsTuple t(scalar, pairwise, broadcast); + return t; + } +} diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index fe6bfae81..e4fef2c3c 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -660,6 +660,98 @@ namespace simdOps { } }; + template + class IntOr { + public: + + op_def static X op(X d1, X d2) { + return d2 | d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class IntAnd { + public: + + op_def static X op(X d1, X d2) { + return d2 & d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class IntXor { + public: + + op_def static X op(X d1, X d2) { + return d2 ^ d1; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class ShiftLeft { + public: + + op_def static X op(X d1, X d2) { + return d1 << d2; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class ShiftRight { + public: + + op_def static X op(X d1, X d2) { + return d1 >> d2; + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class CyclicShiftLeft { + public: + + op_def static X op(X d1, X d2) { + return d1 << d2 | d1 >> ((sizeof(X) * 8) - d2); + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template + class CyclicShiftRight { + public: + + op_def static X op(X d1, X d2) { + return d1 >> d2 | d1 << ((sizeof(X) * 8) - d2); + } + + op_def static X op(X d1, X d2, X *params) { + return op(d1, d2); + } + }; + + template class Or { public: diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 8c484268e..87ac417be 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -623,12 +623,13 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) { TEST_F(DeclarableOpsTests13, shift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(32); e.assign(512); nd4j::ops::shift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -640,12 +641,13 @@ TEST_F(DeclarableOpsTests13, shift_bits_1) { TEST_F(DeclarableOpsTests13, rshift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(512); e.assign(32); nd4j::ops::rshift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -657,12 +659,13 @@ TEST_F(DeclarableOpsTests13, rshift_bits_1) { TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(32); e.assign(512); nd4j::ops::cyclic_shift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); @@ -674,12 +677,107 @@ TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); auto e = x.ulike(); x.assign(512); e.assign(32); nd4j::ops::cyclic_rshift_bits op; - auto result = op.execute({&x}, {}, {4}); + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + nd4j::ops::rshift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::cyclic_shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + nd4j::ops::cyclic_rshift_bits op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} +TEST_F(DeclarableOpsTests13, shift_bits_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x, &y}, {}, {}); ASSERT_EQ(Status::OK(), result->status()); auto z = result->at(0); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 1f361ce9c..035ff0960 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -1234,19 +1234,19 @@ public class DifferentialFunctionFactory { return new Xor(sameDiff(), ix, iy).outputVariable(); } - public SDVariable shift(SDVariable ix, int shift) { + public SDVariable shift(SDVariable ix, SDVariable shift) { return new ShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rshift(SDVariable ix, int shift) { + public SDVariable rshift(SDVariable ix, SDVariable shift) { return new RShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rotl(SDVariable ix, int shift) { + public SDVariable rotl(SDVariable ix, SDVariable shift) { return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); } - public SDVariable rotr(SDVariable ix, int shift) { + public SDVariable rotr(SDVariable ix, SDVariable shift) { return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 70eaa5cd9..10fc0b44a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -2428,7 +2428,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitShift(String name, SDVariable x, int shift) { + public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { validateInteger("shift_bits", x); SDVariable result = f().shift(x, shift); return updateVariableNameAndReference(result, name); @@ -2441,7 +2441,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitShiftRight(String name, SDVariable x, int shift) { + public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { validateInteger("rshift_bits", x); SDVariable result = f().rshift(x, shift); return updateVariableNameAndReference(result, name); @@ -2454,7 +2454,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitRotl(String name, SDVariable x, int shift) { + public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { validateInteger("cyclic_shift_bits", x); SDVariable result = f().rotl(x, shift); return updateVariableNameAndReference(result, name); @@ -2467,7 +2467,7 @@ public class SDMath extends SDOps { * @param x Input 1 * @return Output SDVariable with shifted bits */ - public SDVariable bitRotr(String name, SDVariable x, int shift) { + public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { validateInteger("cyclic_rshift_bits", x); SDVariable result = f().rotr(x, shift); return updateVariableNameAndReference(result, name); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java index 318a7dc02..3a9173654 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class CyclicRShiftBits extends BaseDynamicTransformOp { - public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) { + super(sameDiff, new SDVariable[] {x, shift} ,false); } - public CyclicRShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public CyclicRShiftBits(INDArray input, INDArray shift, INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public CyclicRShiftBits(INDArray input, int shift) { - this(input, shift,null); + public CyclicRShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public CyclicRShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java index b4291c5df..20b6f6955 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class CyclicShiftBits extends BaseDynamicTransformOp { - public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public CyclicShiftBits(SameDiff sameDiff, SDVariable x, SDVariable shift) { + super(sameDiff, new SDVariable[] {x, shift} ,false); } - public CyclicShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public CyclicShiftBits(INDArray input, INDArray shift, INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public CyclicShiftBits(INDArray input, int shift) { - this(input, shift,null); + public CyclicShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public CyclicShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java index 80697efa3..4435615f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class RShiftBits extends BaseDynamicTransformOp { - public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public RShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); } - public RShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public RShiftBits(INDArray input, INDArray shift, INDArray output) { + super(new INDArray[]{input, shift}, new INDArray[]{output}); } - public RShiftBits(INDArray input, int shift) { - this(input, shift,null); + public RShiftBits(INDArray input, INDArray shift) { + this(input, shift,input.ulike()); } public RShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java index 8c652f72d..5501324f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -34,18 +34,16 @@ import java.util.List; */ public class ShiftBits extends BaseDynamicTransformOp { - public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) { - super(sameDiff, new SDVariable[] {x} ,false); - this.addIArgument(shift); + public ShiftBits(SameDiff sameDiff, SDVariable x, SDVariable y) { + super(sameDiff, new SDVariable[] {x, y} ,false); } - public ShiftBits(INDArray input, int shift, INDArray output) { - super(new INDArray[]{input}, new INDArray[]{output}); - this.addIArgument(shift); + public ShiftBits(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x, y}, new INDArray[]{output}); } - public ShiftBits(INDArray input, int shift) { - this(input, shift,null); + public ShiftBits(INDArray x, INDArray y) { + this(x, y,x.ulike()); } public ShiftBits() {} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 603413fd6..15f6c52ef 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include // #include // #include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 38c0cb8c4..eeb4d38c3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3594,6 +3594,7 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc // #include // #include // #include +// #include // #include // #include // #include @@ -16496,15 +16497,18 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * creates identity 2D matrix or batch of identical 2D identity matrices - * + * * Input array: * provide some array - in any case operation simply neglects it - * + * + * Input float argument (if passed): + * TArgs[0] - type of elements of output array, default value is 5 (float) + * * Input integer arguments: * IArgs[0] - order of output identity matrix, 99 -> 'c'-order, 102 -> 'f'-order * IArgs[1] - the number of rows in output inner-most 2D identity matrix * IArgs[2] - optional, the number of columns in output inner-most 2D identity matrix, if this argument is not provided then it is taken to be equal to number of rows - * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape + * IArgs[3,4,...] - optional, shape of batch, output matrix will have leading batch dimensions of this shape */ // #if NOT_EXCLUDED(OP_eye) @Namespace("nd4j::ops") public static class eye extends DeclarableCustomOp { @@ -16598,10 +16602,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * clip a list of given tensors with given average norm when needed - * + * * Input: * a list of tensors (at least one) - * + * * Input floating point argument: * clip_norm - a value that used as threshold value and norm to be used * @@ -16749,12 +16753,12 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * returns histogram (as 1D array) with fixed bins width - * + * * Input arrays: - * - input array with elements to be binned into output histogram + * - input array with elements to be binned into output histogram * - range array with first element being bottom limit and second element being top limit of histogram, please note that input_value <= range[0] will be mapped to histogram[0], input_value >= range[1] will be mapped to histogram[-1] - * + * * Input integer arguments: * nbins (optional) - number of histogram bins, default value is 100 */ @@ -21822,7 +21826,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_shift_bits) - @Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class shift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public shift_bits(Pointer p) { super(p); } @@ -21835,7 +21839,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21847,7 +21850,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_rshift_bits) - @Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class rshift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public rshift_bits(Pointer p) { super(p); } @@ -21860,7 +21863,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21872,7 +21874,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_cyclic_shift_bits) - @Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class cyclic_shift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public cyclic_shift_bits(Pointer p) { super(p); } @@ -21885,7 +21887,6 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public cyclic_shift_bits() { super((Pointer)null); allocate(); } private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif @@ -21897,7 +21898,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * \tparam T */ // #if NOT_EXCLUDED(OP_cyclic_rshift_bits) - @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp { + @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends BroadcastableOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public cyclic_rshift_bits(Pointer p) { super(p); } @@ -21910,6 +21911,30 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public cyclic_rshift_bits() { super((Pointer)null); allocate(); } private native void allocate(); + } +// #endif + + /** + * This operation returns hamming distance based on bits + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_bits_hamming_distance) + @Namespace("nd4j::ops") public static class bits_hamming_distance extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public bits_hamming_distance(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public bits_hamming_distance(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public bits_hamming_distance position(long position) { + return (bits_hamming_distance)super.position(position); + } + + public bits_hamming_distance() { super((Pointer)null); allocate(); } + private native void allocate(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala index 8ca21b72e..e43a1e86e 100644 --- a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala +++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala @@ -90,8 +90,13 @@ class SDVariableWrapper { def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other) def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other) - def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x) - def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x) + def <<(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, other) + def >>(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShiftRight(null, thisVariable, other) + def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShift(null, thisVariable, sameDiff.constant(x)) + def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = + sameDiff.math.bitShiftRight(null, thisVariable, sameDiff.constant(x)) // Overloads for numeric arguments // Float diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala index a2c113b50..8e9f892c6 100644 --- a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala +++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala @@ -188,4 +188,16 @@ class MathTest extends FlatSpec with Matchers { val w3 = w1 >> 2 w3.eval.toIntVector.head shouldBe 4 } + + "SameDiff" should "provide shifting operations with SDVariable argument" in { + implicit val sd = SameDiff.create() + val w1 = sd.constant(16) + val two = sd.constant(2) + + val w2 = w1 << two + w2.eval.toIntVector.head shouldBe 64 + + val w3 = w1 >> two + w3.eval.toIntVector.head shouldBe 4 + } }