diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index a83472899..00a984d45 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2772,9 +2772,9 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector // TODO: eventually we want separate tads here NDArray::prepareSpecialUse({result}, {this, other}); if(max == this) - NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); else - NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); + NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), 
result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); registerSpecialUse({result}, {this, other}); } diff --git a/libnd4j/blas/NativeOpExecutioner.h b/libnd4j/blas/NativeOpExecutioner.h index fb2ca58f0..b4a5fdea4 100644 --- a/libnd4j/blas/NativeOpExecutioner.h +++ b/libnd4j/blas/NativeOpExecutioner.h @@ -284,6 +284,7 @@ static void execScalarInt(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); @@ -296,6 +297,7 @@ static void execScalarInt(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index ff368d7c8..ea2b0b8a5 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -179,6 +179,7 @@ ND4J_EXPORT void execBroadcastBool( void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape); diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index dc27c1cce..75a68c984 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -156,6 +156,9 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if 
(shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -187,7 +190,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType) @@ -219,6 +223,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -228,8 +233,11 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, 
BOOL_TYPES); }; auto xLen = shape::length(hXShapeInfo); @@ -247,22 +255,24 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if (yType != xType || nd4j::DataType::BOOL != zType) throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType); auto func = PRAGMA_THREADS_FOR { - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); }; auto xLen = shape::length(hXShapeInfo); @@ -292,6 +302,9 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); @@ -321,12 +334,13 @@ 
void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); @@ -367,11 +381,13 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -403,6 +419,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); @@ -429,11 +448,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc, void *dZ, Nd4jLong *dZShapeInfo, void *extraParams) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || 
shape::isEmpty(hYShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); @@ -837,11 +858,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dScalarShapeInfo, void *extraParams, bool allowParallelism) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -872,11 +895,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else @@ -904,12 +929,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || 
shape::isEmpty(hSscalarShapeInfo)) + return; + if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); @@ -939,11 +965,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); @@ -969,12 +997,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, void *dScalar, Nd4jLong *dSscalarShapeInfo, void *extraParams, bool allowParallelism) { - - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); @@ -1004,11 +1033,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { - auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || xType != zType) throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); @@ -1126,6 +1157,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); 
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); }; @@ -1145,6 +1179,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); }; @@ -1164,6 +1201,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); }; @@ -1183,6 +1223,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = PRAGMA_THREADS_DO { BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); }; @@ -1202,6 +1245,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + auto func = 
PRAGMA_THREADS_DO { BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); }; diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index df6ccc240..e790c05d0 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -231,8 +231,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, - void *hDimension, Nd4jLong *hDimensionShape, - void *dDimension, Nd4jLong *dDimensionShape) { + void *extraParams, + void *hDimension, Nd4jLong *hDimensionShape, + void *dDimension, Nd4jLong *dDimensionShape) { try { auto dimension = reinterpret_cast(hDimension); int dimensionLength = static_cast(shape::length(hDimensionShape)); @@ -259,6 +260,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, + extraParams, dimension, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, hTADOffsetsZ); diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index fcb473820..1f074f39b 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -101,6 +101,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (xType != zType && yType != zType) throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type"); if (lc == nullptr) @@ -139,6 +142,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if 
(shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); @@ -172,6 +178,9 @@ void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); @@ -223,6 +232,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -233,6 +243,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); @@ -244,7 +257,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc, dim3 launchDims(256, 256, 1024); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, 
opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -260,6 +273,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { @@ -269,18 +283,18 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isB(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); if (yType != xType) throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3BI opNum:[%i]\n", opNum); - dim3 launchDims(256, 256, 1024); - BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) + BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ 
-308,15 +322,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); if (yType != xType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3B opNum:[%i]\n", opNum); - dim3 launchDims(256, 256, 1024); BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) @@ -344,6 +358,9 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; + if (!DataTypeUtils::isZ(zType)) throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); @@ -394,8 +411,8 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if (nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3 opNum:[%i]\n", opNum); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; dim3 launchDims(256, 256, 1024); @@ -429,8 +446,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); - if 
(nd4j::Environment::getInstance()->isDebugAndVerbose()) - printf("F3I opNum:[%i]\n", opNum); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo)) + return; dim3 launchDims(256, 256, 1024); @@ -832,16 +849,21 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (xType != zType) - throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); + if (shape::isEmpty(hXShapeInfo)) { + return; + } + if (xType != zType) { + throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); + } + + dim3 launchDims(512, 512, 16384); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES); // TODO: remove after the release @@ -861,16 +883,21 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); - auto xRank = shape::rank(hXShapeInfo); - auto zRank = shape::rank(hZShapeInfo); - auto xType = ArrayOptions::dataType(hXShapeInfo); + auto xRank = shape::rank(hXShapeInfo); + auto zRank = shape::rank(hZShapeInfo); + auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (!DataTypeUtils::isB(zType)) - throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); + if 
(shape::isEmpty(hXShapeInfo)) { + return; + } + if (!DataTypeUtils::isB(zType)) { + throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); + } + + dim3 launchDims(512, 512, 16384); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES); // TODO: remove after the release @@ -896,6 +923,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + dim3 launchDims(512, 512, 2048); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); @@ -917,16 +947,21 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { auto stream = lc->getCudaStream(); - dim3 launchDims(512, 512, 16384); auto xRank = shape::rank(hXShapeInfo); auto zRank = shape::rank(hZShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - if (xType != zType || !DataTypeUtils::isR(xType)) - throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); + if (shape::isEmpty(hXShapeInfo)) { + return; + } + if (xType != zType || !DataTypeUtils::isR(xType)) { + throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); + } + + dim3 launchDims(512, 512, 16384); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, 
::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); // TODO: remove after the release @@ -953,6 +988,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo)) + return; + if (!DataTypeUtils::isR(zType)) throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); @@ -1175,6 +1213,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType ) throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); @@ -1211,6 +1252,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType ) throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); @@ -1244,6 +1288,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); @@ -1280,6 +1327,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = 
nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + if (xType != yType || zType != xType) throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); @@ -1313,6 +1363,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + #ifdef __ND4J_EXPERIMENTAL__ BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); @@ -1346,6 +1399,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc, auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); + if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo)) + return; + dim3 launchDims(256, 256, 16384); #ifdef __ND4J_EXPERIMENTAL__ diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index cda6acbad..024012808 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -294,6 +294,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, void *dY, Nd4jLong *dYShapeInfo, void *hZ, Nd4jLong *hZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo, + void *extraParams, void *hDimension, Nd4jLong *hDimensionShape, void *dDimension, Nd4jLong *dDimensionShape) { try { @@ -313,7 +314,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers, LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, - dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, + dYShapeInfo, hZ, hZShapeInfo, dZ, 
dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); } catch (std::exception &e) { diff --git a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp index c4c2fa995..dbf080ac9 100644 --- a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp +++ b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp @@ -20,6 +20,7 @@ #include #include +#include using namespace simdOps; @@ -47,36 +48,39 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; ++i) { - shape::index2coords(i, zShapeInfo, zCoords.data()); + shape::index2coords(i, zShapeInfo, zCoords.data()); - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } + + if (iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; - const auto 
xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template @@ -103,38 +107,40 @@ void TrueBroadcastBoolHelper::exec(const NDArray& xArr, const NDArray& yAr const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + shape::index2coords(i, zShapeInfo, zCoords.data()); - shape::index2coords(i, zShapeInfo, zCoords.data()); + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } + }; - const auto xOffset = 
shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template @@ -163,36 +169,39 @@ void TrueBroadcastIntHelper::exec(const NDArray& xArr, const NDArray& yArr, N - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) - for (Nd4jLong i = 0; i < zLen; ++i) { + auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); // scratch coords local to each func invocation (previously firstprivate) + for (auto i = start; i < stop; ++i) { - shape::index2coords(i, zShapeInfo, zCoords.data()); + shape::index2coords(i, zShapeInfo, zCoords.data()); - for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { + for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { - xCoords[ix--] = zCoords[iz]; - } else { - xCoords[ix--] = 0; + if (ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { + xCoords[ix--] = zCoords[iz]; + } else { + xCoords[ix--] = 0; + } + } + + if (iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { + yCoords[iy--] = zCoords[iz]; + } else { + yCoords[iy--] = 0; + } } } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { - yCoords[iy--] = zCoords[iz]; - } else { - yCoords[iy--] = 0; - } - } + const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); + const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); + const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data()); + + z[zOffset] = OpType::op(x[xOffset], y[yOffset]); } + }; - const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); - const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data()); - const auto zOffset = shape::getOffset(zShapeInfo,
zCoords.data()); - - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); - } + samediff::Threads::parallel_for(func, 0, zLen); } template diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index 8f67f0004..4b7904bca 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -86,7 +86,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } @@ -172,7 +172,7 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords); - z[zOffset] = OpType::op(x[xOffset], y[yOffset]); + z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr); } } diff --git a/libnd4j/include/loops/broadcasting_bool.h b/libnd4j/include/loops/broadcasting_bool.h index 3b0958be1..7ba5fa9eb 100644 --- a/libnd4j/include/loops/broadcasting_bool.h +++ b/libnd4j/include/loops/broadcasting_bool.h @@ -65,13 +65,14 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); template - static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, 
Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); - static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); template static __device__ void transformInverseCuda( @@ -81,13 +82,14 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); template - static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); - static __host__ void execInverseBroadcast(dim3 
launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); + static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); #else @@ -98,6 +100,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, @@ -114,6 +117,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, @@ -141,6 +145,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, @@ -157,6 +162,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadShapeInfo, diff --git a/libnd4j/include/loops/cpu/broadcasting_bool.cpp b/libnd4j/include/loops/cpu/broadcasting_bool.cpp index 7a3eb1e31..8d62b9506 100644 --- a/libnd4j/include/loops/cpu/broadcasting_bool.cpp +++ b/libnd4j/include/loops/cpu/broadcasting_bool.cpp @@ -39,6 +39,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -53,6 +54,7 @@ namespace functions { yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, xTadShapeInfo, @@ -69,6 +71,7 @@ namespace functions { Nd4jLong 
*yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -83,6 +86,7 @@ namespace functions { yShapeInfo, z, zShapeInfo, + extraParams, dimension, dimensionLength, xTadShapeInfo, @@ -99,6 +103,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *xTadShapeInfo, @@ -111,6 +116,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -155,7 +161,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(oX[f], y[f]); + oZ[f] = OpType::op(oX[f], y[f], extraParams); } } else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { @@ -165,7 +171,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); + oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams); }; } else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { @@ -179,7 +185,7 @@ namespace functions { PRAGMA_OMP_SIMD for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); - oZ[offset] = OpType::op(oX[offset], y[offset]); + oZ[offset] = OpType::op(oX[offset], y[offset], extraParams); } }; } @@ -197,7 +203,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[offset], y[offset]); + oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams); } }; } 
@@ -215,7 +221,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[offset], y[yOffset]); + oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams); } }; @@ -234,7 +240,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); - oZ[offset] = OpType::op(oX[xOffset], y[offset]); + oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams); } }; } @@ -255,7 +261,7 @@ namespace functions { auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); + oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams); } }; } @@ -270,6 +276,7 @@ namespace functions { Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *yTadShapeInfo, @@ -282,6 +289,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -326,7 +334,7 @@ namespace functions { PRAGMA_OMP_SIMD for (unsigned int f = 0; f < tadLength; f++) - oZ[f] = OpType::op(x[f], oY[f]); + oZ[f] = OpType::op(x[f], oY[f], extraParams); } } else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { @@ -336,7 +344,7 @@ namespace functions { PRAGMA_OMP_SIMD for (uint f = 0; f < tadLength; f++) - oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); + 
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams); } } else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { @@ -351,7 +359,7 @@ namespace functions { PRAGMA_OMP_SIMD for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); - oZ[offset] = OpType::op(x[offset], oY[offset]); + oZ[offset] = OpType::op(x[offset], oY[offset], extraParams); } } } @@ -370,7 +378,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[offset], oY[offset]); + oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams); } } } @@ -389,7 +397,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[xOffset], oY[offset]); + oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams); } } } @@ -408,7 +416,7 @@ namespace functions { for (int f = 0; f < tadLength; f++) { auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); - oZ[offset] = OpType::op(x[offset], oY[yOffset]); + oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams); } } } @@ -430,7 +438,7 @@ namespace functions { auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); - oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); + oZ[zOffset] = 
OpType::op(x[xOffset], oY[yOffset], extraParams); } } } diff --git a/libnd4j/include/loops/cuda/broadcasting_bool.cu b/libnd4j/include/loops/cuda/broadcasting_bool.cu index af354a2e2..d5a45ceec 100644 --- a/libnd4j/include/loops/cuda/broadcasting_bool.cu +++ b/libnd4j/include/loops/cuda/broadcasting_bool.cu @@ -40,10 +40,11 @@ static __global__ void broadcastBoolSimple( Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - functions::broadcast::BroadcastBool::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + functions::broadcast::BroadcastBool::template transformCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } ////////////////////////////////////////////////////////////////////////// @@ -55,10 +56,11 @@ static __global__ void broadcastBoolInverseSimple( Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, + void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); + functions::broadcast::BroadcastBool::template transformInverseCuda(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); } namespace functions { @@ -66,15 +68,15 @@ namespace functions { ////////////////////////////////////////////////////////////////////////// template template - __host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, 
Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + __host__ void BroadcastBool::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); } ////////////////////////////////////////////////////////////////////////// template - __host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) + __host__ void BroadcastBool::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + 
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DEBUG_KERNEL(stream, opNum); } @@ -82,15 +84,15 @@ namespace functions { ////////////////////////////////////////////////////////////////////////// template template - __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); + __host__ void BroadcastBool::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + broadcastBoolInverseSimple<<>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) 
failed"); } ////////////////////////////////////////////////////////////////////////// template - __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { - DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) + __host__ void BroadcastBool::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { + DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DEBUG_KERNEL(stream, opNum); } @@ -102,6 +104,7 @@ namespace functions { void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -113,6 +116,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -140,7 +144,7 @@ namespace functions { if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { for 
(int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); + rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams); } else { // it is expected that x and z tads and y array all have the same length @@ -149,7 +153,7 @@ namespace functions { auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); + rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams); } } } @@ -162,6 +166,7 @@ namespace functions { void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, Nd4jLong *zShapeInfo, + void *vextraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { @@ -173,6 +178,7 @@ namespace functions { auto x = reinterpret_cast(vx); auto y = reinterpret_cast(vy); auto z = reinterpret_cast(vz); + auto extraParams = reinterpret_cast(vextraParams); //decompose in to several sub tads after //moving all dimensions (in sorted order) @@ -208,7 +214,7 @@ namespace functions { if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) - rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); + rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams); } else { // it is expected that x and z tads and y array all have the same length @@ -217,7 +223,7 @@ namespace functions { auto yOffset = shape::getIndexOffset(i, yShapeInfo); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); - rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); + rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams); } } } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 918850e34..5108ba4a7 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -45,13 +45,13 @@ 
(2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) - + (10, Not) ,\ + (11, LessThanOrEqual) #define BROADCAST_OPS \ (0, Add), \ @@ -198,12 +198,13 @@ (2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) + (10, Not) ,\ + (11, LessThanOrEqual) #define SCALAR_OPS \ (0, Add),\ @@ -341,12 +342,13 @@ (2, LessThan),\ (3, Epsilon),\ (4, GreaterThanOrEqual),\ - (5, LessThanOrEqual),\ + (5, MatchCondition) ,\ (6, NotEqualTo),\ (7, And),\ (8, Or),\ (9, Xor) ,\ - (10, Not) + (10, Not) ,\ + (11, LessThanOrEqual) #define PAIRWISE_TRANSFORM_OPS \ (0, Add),\ diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index 1cdf08130..f3f9f9699 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -2302,54 +2302,66 @@ namespace simdOps { return old + opOutput; } - // this op return 1.0 if condition met, 0.0 otherwise - op_def static Z op(X d1, X *extraParams) { - X compare = extraParams[0]; - X eps = extraParams[1]; - - auto mode = static_cast(extraParams[2]); - //printf("value: %f; comp: %f; eps: %f; mode: %i;\n", (float) d1, (float) compare, (float) eps, mode); - - switch (mode) { - case 0: // equals - return nd4j::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; - case 1: // not equals - return nd4j::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; - case 2: // less_than - return d1 < compare ? 1 : 0; - case 3: // greater_than - return d1 > compare ? 1 : 0; - case 4: // less_or_equals_than - return d1 <= compare ? 1 : 0; - case 5: // greater_or_equals_than - return d1 >= compare ? 1 : 0; - case 6: // abs_less_than - return nd4j::math::nd4j_abs(d1) < compare ? 1 : 0; - case 7: // abs_greater_than - return nd4j::math::nd4j_abs(d1) > compare ? 1 : 0; - case 8: // is inf - return nd4j::math::nd4j_isinf(d1) ? 
1 : 0; - case 9: // is nan - return nd4j::math::nd4j_isnan(d1) ? 1 : 0; - case 10: - return (d1 == compare) ? 1 : 0; - case 11: - return (d1 != compare) ? 1 : 0; - case 12: // abs_greater_or_equals_than - return nd4j::math::nd4j_abs(d1) >= compare ? 1 : 0; - case 13: // abs_less_or_equals_than - return nd4j::math::nd4j_abs(d1) <= compare ? 1 : 0; + op_def static Z op(X d1, X compare, X eps, int mode) { + switch (mode) { + case 0: // equals + return nd4j::math::nd4j_abs(d1 - compare) <= eps ? 1 : 0; + case 1: // not equals + return nd4j::math::nd4j_abs(d1 - compare) > eps ? 1 : 0; + case 2: // less_than + return d1 < compare ? 1 : 0; + case 3: // greater_than + return d1 > compare ? 1 : 0; + case 4: // less_or_equals_than + return d1 <= compare ? 1 : 0; + case 5: // greater_or_equals_than + return d1 >= compare ? 1 : 0; + case 6: // abs_less_than + return nd4j::math::nd4j_abs(d1) < compare ? 1 : 0; + case 7: // abs_greater_than + return nd4j::math::nd4j_abs(d1) > compare ? 1 : 0; + case 8: // is inf + return nd4j::math::nd4j_isinf(d1) ? 1 : 0; + case 9: // is nan + return nd4j::math::nd4j_isnan(d1) ? 1 : 0; + case 10: + return (d1 == compare) ? 1 : 0; + case 11: + return (d1 != compare) ? 1 : 0; + case 12: // abs_greater_or_equals_than + return nd4j::math::nd4j_abs(d1) >= compare ? 1 : 0; + case 13: // abs_less_or_equals_than + return nd4j::math::nd4j_abs(d1) <= compare ? 1 : 0; case 14: // isFinite return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0; case 15: // isInfinite return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 
1 : 0; - default: - printf("Undefined match condition: [%i]\n", mode); - } + default: + printf("Undefined match condition: [%i]\n", mode); + } return d1; + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X compare, X *extraParams) { + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[0]); + + return op(d1, compare, eps, mode); + } + + // this op return 1.0 if condition met, 0.0 otherwise + op_def static Z op(X d1, X *extraParams) { + X compare = extraParams[0]; + X eps = extraParams[1]; + + auto mode = static_cast(extraParams[2]); + + return op(d1, compare, eps, mode); } op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { diff --git a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu index 9180393ca..c8b6fa1d9 100644 --- a/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu +++ b/libnd4j/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -1342,6 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) { nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, (int*)devicePtrs[0], dimensions.size(), (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); @@ -1400,6 +1401,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) { nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, (int*)devicePtrs[0], dimensions.size(), (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index f48ee54f6..ffcd5759e 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ 
b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -674,6 +674,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, 0, tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); diff --git a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp index 95b3027cc..1846fc397 100644 --- a/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -202,6 +202,7 @@ printf("Unsupported for cuda now.\n"); nullptr, nullptr, exp.buffer(), exp.shapeInfo(), nullptr, nullptr, + nullptr, dimension.buffer(), dimension.shapeInfo(), nullptr, nullptr); ASSERT_TRUE(exp.e(1) && !exp.e(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index f76c42c50..33d983f23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -417,7 +417,7 @@ public class LegacyOpMapper { case 4: return ScalarGreaterThanOrEqual.class; case 5: - return ScalarLessThanOrEqual.class; + return MatchCondition.class; case 6: return ScalarNotEquals.class; case 7: @@ -428,6 +428,8 @@ public class LegacyOpMapper { return ScalarXor.class; case 10: return ScalarNot.class; + case 11: + return ScalarLessThanOrEqual.class; default: throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 16931e434..77b946559 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -1864,7 +1864,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray match(INDArray comp, Condition condition) { - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition)); + // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition + Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); + Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java index e9ee1db2c..dc9017582 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java @@ -64,7 +64,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp { @Override public int opNum() { - return 5; + return 11; } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java index 4dede1775..6c9a3a893 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java @@ -55,7 +55,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { @Override public int opNum() { - return 5; + return 11; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java index 0ae4b266c..dea1c9c3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/MatchConditionTransform.java @@ -53,6 +53,11 @@ public class MatchConditionTransform extends BaseTransformBoolOp { public MatchConditionTransform() {} + public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray y, @NonNull INDArray z, @NonNull Condition condition) { + this(x, z, Nd4j.EPS_THRESHOLD, condition); + this.y = y; + } + public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) { this(x, z, Nd4j.EPS_THRESHOLD, condition); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java index e1525f381..f426223ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/indexing/conditions/Conditions.java @@ -25,71 +25,283 @@ import org.nd4j.linalg.factory.Nd4j; */ public class Conditions { - + /** + * This method will create Condition that checks if value is infinite + * @return + */ public static Condition isInfinite() { return new IsInfinite(); } + /** + * This method will create Condition that checks if value is NaN + * @return + */ public static Condition isNan() { return new IsNaN(); } + /** + * This method will create Condition that checks if value is finite + * @return + */ public static Condition isFinite() { return new IsFinite(); } + /** + * This method will create Condition that checks if value is NOT finite + * @return + */ public static Condition notFinite() { return new NotFinite(); } + /** + * This method will create Condition that checks if value is two values are not equal wrt eps + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition epsNotEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return epsNotEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are not equal wrt eps + * + * @return + */ public static Condition epsNotEquals(Number value) { return new EpsilonNotEquals(value); } + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
+ * @return + */ + public static Condition epsEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return epsEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * @return + */ public static Condition epsEquals(Number value) { return epsEquals(value, Nd4j.EPS_THRESHOLD); } + /** + * This method will create Condition that checks if value is two values are equal wrt eps + * + * @return + */ public static Condition epsEquals(Number value, Number epsilon) { return new EpsilonEquals(value, epsilon.doubleValue()); } + /** + * This method will create Condition that checks if value is two values are equal + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition equals() { + // in case of pairwise MatchCondition we don't really care about number here + return equals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are equal + * + * @return + */ public static Condition equals(Number value) { return new EqualsCondition(value); } + /** + * This method will create Condition that checks if value is two values are not equal + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition notEquals() { + // in case of pairwise MatchCondition we don't really care about number here + return notEquals(0.0); + } + + /** + * This method will create Condition that checks if value is two values are not equal + * + * @return + */ public static Condition notEquals(Number value) { return new NotEqualsCondition(value); } + /** + * This method will create Condition that checks if value is value X is greater than value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
+ * @return + */ + public static Condition greaterThan() { + // in case of pairwise MatchCondition we don't really care about number here + return greaterThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than value Y + * + * @return + */ public static Condition greaterThan(Number value) { return new GreaterThan(value); } + /** + * This method will create Condition that checks if value is value X is less than value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition lessThan() { + // in case of pairwise MatchCondition we don't really care about number here + return lessThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than value Y + * + * @return + */ public static Condition lessThan(Number value) { return new LessThan(value); } + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition lessThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return lessThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y + * + * @return + */ public static Condition lessThanOrEqual(Number value) { return new LessThanOrEqual(value); } + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
+ * @return + */ + public static Condition greaterThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return greaterThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y + * + * @return + */ public static Condition greaterThanOrEqual(Number value) { return new GreaterThanOrEqual(value); } + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absGreaterThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return absGreaterThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values + * + * @return + */ public static Condition absGreaterThanOrEqual(Number value) { return new AbsValueGreaterOrEqualsThan(value); } + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absLessThanOrEqual() { + // in case of pairwise MatchCondition we don't really care about number here + return absLessThanOrEqual(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values + * + * @return + */ public static Condition absLessThanOrEqual(Number value) { return new AbsValueLessOrEqualsThan(value); } + /** + * This method will create Condition that checks if value is value X is greater than value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) 
+ * @return + */ + public static Condition absGreaterThan() { + // in case of pairwise MatchCondition we don't really care about number here + return absGreaterThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is greater than value Y in absolute values + * + * @return + */ public static Condition absGreaterThan(Number value) { return new AbsValueGreaterThan(value); } + /** + * This method will create Condition that checks if value is value X is less than value Y in absolute values + * + * PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...) + * @return + */ + public static Condition absLessThan() { + // in case of pairwise MatchCondition we don't really care about number here + return absLessThan(0.0); + } + + /** + * This method will create Condition that checks if value is value X is less than value Y in absolute values + * + * @return + */ public static Condition absLessThan(Number value) { return new AbsValueLessThan(value); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index d4a7b8f8b..741978a3c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -129,6 +129,7 @@ public interface NativeOps { @Cast("Nd4jLong *") LongPointer resultShapeInfo, Pointer dresult, @Cast("Nd4jLong *") LongPointer dresultShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 20f2b5f22..904e1305e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -198,7 +198,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), - null, + null, null, (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); @@ -805,7 +805,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, - null, + null, null, (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), AtomicAllocator.getInstance().getPointer(op.dimensions(), context), null); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 1799ceb22..f567873a2 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ 
b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -916,6 +916,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @@ -927,6 +928,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @@ -938,6 +940,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index a2964b7a6..751f75cea 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -967,6 +967,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { null, null, op.z().data().addressPointer(), (LongPointer) 
op.z().shapeInfoDataBuffer().addressPointer(), null, null, + null, op.dimensions().data().addressPointer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null, diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 828d9b290..0ee807594 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -916,6 +916,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); public native void execBroadcastBool( @@ -927,6 +928,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); public native void execBroadcastBool( @@ -938,6 +940,7 @@ public native void execBroadcastBool( Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, + Pointer extraParams, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 84dd02cd4..ca075c872 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -46,6 +46,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; +import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.nativeblas.NativeOpsHolder; import java.util.ArrayList; @@ -1087,6 +1088,24 @@ public class CustomOpsTests extends BaseNd4jTest { assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + @Test + public void testMatch_1() { + INDArray x = Nd4j.ones(DataType.FLOAT, 3,3); + INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3); + val c = Conditions.equals(0.0); + + System.out.println("Y:\n" + y); + + INDArray z = x.match(y, c); + INDArray exp = Nd4j.createFromArray(new boolean[][]{ + {false, false, false}, + {false, false, false}, + {true, false, false} + }); + + assertEquals(exp, z); + } + @Test public void testCreateOp_1() {