Few fixes (#66)

* skip legacy transforms execution in case of empty input arrays

Signed-off-by: raver119 <raver119@gmail.com>

* - BroadcastBool ops now accept extraParams to make MatchCondition possible
  - TrueBroadcastHelper now uses samediff::threads

Signed-off-by: raver119 <raver119@gmail.com>

* java side

Signed-off-by: raver119 <raver119@gmail.com>

* trigger jenkins

Signed-off-by: raver119 <raver119@gmail.com>

* update LessThanOrEqual opNum mapping

Signed-off-by: raver119 <raver119@gmail.com>

* update LessThanOrEqual opNum mapping

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-21 15:43:03 +03:00 committed by GitHub
parent 83cb0d9329
commit 064a56ccf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 607 additions and 203 deletions

View File

@ -2772,9 +2772,9 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
// TODO: eventually we want separate tads here // TODO: eventually we want separate tads here
NDArray::prepareSpecialUse({result}, {this, other}); NDArray::prepareSpecialUse({result}, {this, other});
if(max == this) if(max == this)
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else else
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets()); NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
registerSpecialUse({result}, {this, other}); registerSpecialUse({result}, {this, other});
} }

View File

@ -284,6 +284,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ); Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
@ -296,6 +297,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo, void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

View File

@ -179,6 +179,7 @@ ND4J_EXPORT void execBroadcastBool(
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape, void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape); void *dDimension, Nd4jLong *dDimensionShape);

View File

@ -156,6 +156,9 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
@ -187,7 +190,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if (!nd4j::Environment::getInstance()->isExperimentalBuild())
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType) if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
@ -219,6 +223,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -228,8 +233,11 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
}; };
auto xLen = shape::length(hXShapeInfo); auto xLen = shape::length(hXShapeInfo);
@ -247,22 +255,24 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!nd4j::Environment::getInstance()->isExperimentalBuild()) if (!nd4j::Environment::getInstance()->isExperimentalBuild())
if (yType != xType || nd4j::DataType::BOOL != zType) if (yType != xType || nd4j::DataType::BOOL != zType)
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType); throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
}; };
auto xLen = shape::length(hXShapeInfo); auto xLen = shape::length(hXShapeInfo);
@ -292,6 +302,9 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);
@ -321,12 +334,13 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
@ -367,11 +381,13 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
@ -403,6 +419,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType) if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);
@ -429,11 +448,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) { void *extraParams) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
@ -837,11 +858,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dScalarShapeInfo, void *dScalar, Nd4jLong *dScalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
@ -872,11 +895,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else #else
@ -904,12 +929,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo, void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
return;
if (xType != yType) if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
@ -939,11 +965,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType) if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
@ -969,12 +997,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo, void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) { void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
return;
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
@ -1004,11 +1033,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || xType != zType) if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
@ -1126,6 +1157,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
}; };
@ -1145,6 +1179,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
}; };
@ -1164,6 +1201,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
}; };
@ -1183,6 +1223,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
}; };
@ -1202,6 +1245,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo); auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO { auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
}; };

View File

@ -231,8 +231,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape, void *extraParams,
void *dDimension, Nd4jLong *dDimensionShape) { void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape) {
try { try {
auto dimension = reinterpret_cast<int *>(hDimension); auto dimension = reinterpret_cast<int *>(hDimension);
int dimensionLength = static_cast<int>(shape::length(hDimensionShape)); int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -259,6 +260,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
dYShapeInfo, dYShapeInfo,
hZ, hZShapeInfo, hZ, hZShapeInfo,
dZ, dZShapeInfo, dZ, dZShapeInfo,
extraParams,
dimension, dimension,
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ, dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ,
hTADOffsetsZ); hTADOffsetsZ);

View File

@ -101,6 +101,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != zType && yType != zType) if (xType != zType && yType != zType)
throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type"); throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type");
if (lc == nullptr) if (lc == nullptr)
@ -139,6 +142,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType)) if (!DataTypeUtils::isB(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
@ -172,6 +178,9 @@ void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType)) if (!DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType); throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
@ -223,6 +232,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -233,6 +243,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType)) if (!DataTypeUtils::isB(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
@ -244,7 +257,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
// TODO: remove after the release // TODO: remove after the release
auto res = cudaStreamSynchronize(*stream); auto res = cudaStreamSynchronize(*stream);
@ -260,6 +273,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -269,18 +283,18 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType)) if (!DataTypeUtils::isB(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
if (yType != xType) if (yType != xType)
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type");
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3BI opNum:[%i]\n", opNum);
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES) BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
// TODO: remove after the release // TODO: remove after the release
auto res = cudaStreamSynchronize(*stream); auto res = cudaStreamSynchronize(*stream);
@ -308,15 +322,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType)) if (!DataTypeUtils::isZ(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
if (yType != xType || zType != xType) if (yType != xType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3B opNum:[%i]\n", opNum);
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES) BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)
@ -344,6 +358,9 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType)) if (!DataTypeUtils::isZ(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type"); throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
@ -394,8 +411,8 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (nd4j::Environment::getInstance()->isDebugAndVerbose()) if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
printf("F3 opNum:[%i]\n", opNum); return;
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
@ -429,8 +446,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (nd4j::Environment::getInstance()->isDebugAndVerbose()) if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
printf("F3I opNum:[%i]\n", opNum); return;
dim3 launchDims(256, 256, 1024); dim3 launchDims(256, 256, 1024);
@ -832,16 +849,21 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo); auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo); auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
if (xType != zType) if (shape::isEmpty(hXShapeInfo)) {
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type"); return;
}
if (xType != zType) {
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type");
}
dim3 launchDims(512, 512, 16384);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES);
// TODO: remove after the release // TODO: remove after the release
@ -861,16 +883,21 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo); auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo); auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
if (!DataTypeUtils::isB(zType)) if (shape::isEmpty(hXShapeInfo)) {
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type"); return;
}
if (!DataTypeUtils::isB(zType)) {
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");
}
dim3 launchDims(512, 512, 16384);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES);
// TODO: remove after the release // TODO: remove after the release
@ -896,6 +923,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
dim3 launchDims(512, 512, 2048); dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
@ -917,16 +947,21 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream(); auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo); auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo); auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
if (xType != zType || !DataTypeUtils::isR(xType)) if (shape::isEmpty(hXShapeInfo)) {
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); return;
}
if (xType != zType || !DataTypeUtils::isR(xType)) {
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
}
dim3 launchDims(512, 512, 16384);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
// TODO: remove after the release // TODO: remove after the release
@ -953,6 +988,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo); auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
if (!DataTypeUtils::isR(zType)) if (!DataTypeUtils::isR(zType))
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
@ -1175,6 +1213,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType ) if (xType != yType )
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
@ -1211,6 +1252,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType ) if (xType != yType )
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type"); throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
@ -1244,6 +1288,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || zType != xType) if (xType != yType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
@ -1280,6 +1327,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || zType != xType) if (xType != yType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type"); throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
@ -1313,6 +1363,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES); BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
@ -1346,6 +1399,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo); auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
dim3 launchDims(256, 256, 16384); dim3 launchDims(256, 256, 16384);
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__

View File

@ -294,6 +294,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void *dY, Nd4jLong *dYShapeInfo, void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo, void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo, void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape, void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape) { void *dDimension, Nd4jLong *dDimensionShape) {
try { try {
@ -313,7 +314,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY,
dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension,
dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ,
tadOffsetsZ); tadOffsetsZ);
} catch (std::exception &e) { } catch (std::exception &e) {

View File

@ -20,6 +20,7 @@
#include <TrueBroadcastHelper.h> #include <TrueBroadcastHelper.h>
#include <ops/ops.h> #include <ops/ops.h>
#include <execution/Threads.h>
using namespace simdOps; using namespace simdOps;
@ -47,36 +48,39 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong i = 0; i < zLen; ++i) { for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
} else { } else {
xCoords[ix--] = 0; xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
} }
} }
if(iy >= 0) { const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
yCoords[iy--] = zCoords[iz]; const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
} else {
yCoords[iy--] = 0; z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
} }
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); samediff::Threads::parallel_for(func, 0, zLen);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
} }
template <typename X, typename Y, typename Z> template <typename X, typename Y, typename Z>
@ -103,38 +107,40 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) shape::index2coords(i, zShapeInfo, zCoords.data());
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if(ix >= 0) { if (iy >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
} else { } else {
xCoords[ix--] = 0; yCoords[iy--] = 0;
}
} }
} }
if(iy >= 0) { const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
yCoords[iy--] = zCoords[iz]; const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
} else {
yCoords[iy--] = 0; z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
}
} }
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); samediff::Threads::parallel_for(func, 0, zLen);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
} }
template <typename X, typename Y> template <typename X, typename Y>
@ -163,36 +169,39 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords)) auto func = PRAGMA_THREADS_FOR {
for (Nd4jLong i = 0; i < zLen; ++i) { for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) { if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) { if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
} else { } else {
xCoords[ix--] = 0; xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
} }
} }
if(iy >= 0) { const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) { const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
yCoords[iy--] = zCoords[iz]; const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
} else {
yCoords[iy--] = 0; z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
}
} }
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data()); samediff::Threads::parallel_for(func, 0, zLen);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
} }
template <typename X> template <typename X>

View File

@ -86,7 +86,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]); z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
} }
} }
@ -172,7 +172,7 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords); const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]); z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
} }
} }

View File

@ -65,13 +65,14 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType> template<typename OpType>
static __device__ void transformInverseCuda( static __device__ void transformInverseCuda(
@ -81,13 +82,14 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass> template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ); static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
#else #else
@ -98,6 +100,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
@ -114,6 +117,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
@ -141,6 +145,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,
@ -157,6 +162,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *result, void *result,
Nd4jLong *resultShapeInfo, Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *tadShapeInfo, Nd4jLong *tadShapeInfo,

View File

@ -39,6 +39,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
@ -53,6 +54,7 @@ namespace functions {
yShapeInfo, yShapeInfo,
z, z,
zShapeInfo, zShapeInfo,
extraParams,
dimension, dimension,
dimensionLength, dimensionLength,
xTadShapeInfo, xTadShapeInfo,
@ -69,6 +71,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
@ -83,6 +86,7 @@ namespace functions {
yShapeInfo, yShapeInfo,
z, z,
zShapeInfo, zShapeInfo,
extraParams,
dimension, dimension,
dimensionLength, dimensionLength,
xTadShapeInfo, xTadShapeInfo,
@ -99,6 +103,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *xTadShapeInfo, Nd4jLong *xTadShapeInfo,
@ -111,6 +116,7 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after //decompose in to several sub tads after
//moving all dimensions (in sorted order) //moving all dimensions (in sorted order)
@ -155,7 +161,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], y[f]); oZ[f] = OpType::op(oX[f], y[f], extraParams);
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
@ -165,7 +171,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams);
}; };
} }
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
@ -179,7 +185,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
oZ[offset] = OpType::op(oX[offset], y[offset]); oZ[offset] = OpType::op(oX[offset], y[offset], extraParams);
} }
}; };
} }
@ -197,7 +203,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[offset], y[offset]); oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams);
} }
}; };
} }
@ -215,7 +221,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[offset], y[yOffset]); oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams);
} }
}; };
@ -234,7 +240,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[xOffset], y[offset]); oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams);
} }
}; };
} }
@ -255,7 +261,7 @@ namespace functions {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]); oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams);
} }
}; };
} }
@ -270,6 +276,7 @@ namespace functions {
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *vz, void *vz,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int *dimension,
int dimensionLength, int dimensionLength,
Nd4jLong *yTadShapeInfo, Nd4jLong *yTadShapeInfo,
@ -282,6 +289,7 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx); auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy); auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz); auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after //decompose in to several sub tads after
//moving all dimensions (in sorted order) //moving all dimensions (in sorted order)
@ -326,7 +334,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++) for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(x[f], oY[f]); oZ[f] = OpType::op(x[f], oY[f], extraParams);
} }
} }
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) { else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
@ -336,7 +344,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (uint f = 0; f < tadLength; f++) for (uint f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]); oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams);
} }
} }
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) { else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
@ -351,7 +359,7 @@ namespace functions {
PRAGMA_OMP_SIMD PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
oZ[offset] = OpType::op(x[offset], oY[offset]); oZ[offset] = OpType::op(x[offset], oY[offset], extraParams);
} }
} }
} }
@ -370,7 +378,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[offset], oY[offset]); oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams);
} }
} }
} }
@ -389,7 +397,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[xOffset], oY[offset]); oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams);
} }
} }
} }
@ -408,7 +416,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) { for (int f = 0; f < tadLength; f++) {
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[offset], oY[yOffset]); oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams);
} }
} }
} }
@ -430,7 +438,7 @@ namespace functions {
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX); auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY); auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ); auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]); oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams);
} }
} }
} }

View File

@ -40,10 +40,11 @@ static __global__ void broadcastBoolSimple(
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
@ -55,10 +56,11 @@ static __global__ void broadcastBoolInverseSimple(
Nd4jLong *yShapeInfo, Nd4jLong *yShapeInfo,
void *z, void *z,
Nd4jLong *zShapeInfo, Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension, int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ); functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
} }
namespace functions { namespace functions {
@ -66,15 +68,15 @@ namespace functions {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
} }
@ -82,15 +84,15 @@ namespace functions {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Z> template<typename X, typename Z>
template <typename OpClass> template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ); broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed"); nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
} }
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
template<typename X, typename Y> template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { __host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS)) DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum); DEBUG_KERNEL(stream, opNum);
} }
@ -102,6 +104,7 @@ namespace functions {
void *vx, Nd4jLong *xShapeInfo, void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@ -113,6 +116,7 @@ namespace functions {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after //decompose in to several sub tads after
//moving all dimensions (in sorted order) //moving all dimensions (in sorted order)
@ -140,7 +144,7 @@ namespace functions {
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) { if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]); rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams);
} }
else { else {
// it is expected that x and z tads and y array all have the same length // it is expected that x and z tads and y array all have the same length
@ -149,7 +153,7 @@ namespace functions {
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo); auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]); rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams);
} }
} }
} }
@ -162,6 +166,7 @@ namespace functions {
void *vx, Nd4jLong *xShapeInfo, void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo, void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo, void *vz, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength, int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) { Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@ -173,6 +178,7 @@ namespace functions {
auto x = reinterpret_cast<X*>(vx); auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy); auto y = reinterpret_cast<X*>(vy);
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after //decompose in to several sub tads after
//moving all dimensions (in sorted order) //moving all dimensions (in sorted order)
@ -208,7 +214,7 @@ namespace functions {
if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) { if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x) for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]); rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams);
} }
else { else {
// it is expected that x and z tads and y array all have the same length // it is expected that x and z tads and y array all have the same length
@ -217,7 +223,7 @@ namespace functions {
auto yOffset = shape::getIndexOffset(i, yShapeInfo); auto yOffset = shape::getIndexOffset(i, yShapeInfo);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ); auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]); rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams);
} }
} }
} }

View File

@ -45,13 +45,13 @@
(2, LessThan),\ (2, LessThan),\
(3, Epsilon),\ (3, Epsilon),\
(4, GreaterThanOrEqual),\ (4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\ (5, MatchCondition) ,\
(6, NotEqualTo),\ (6, NotEqualTo),\
(7, And),\ (7, And),\
(8, Or),\ (8, Or),\
(9, Xor) ,\ (9, Xor) ,\
(10, Not) (10, Not) ,\
(11, LessThanOrEqual)
#define BROADCAST_OPS \ #define BROADCAST_OPS \
(0, Add), \ (0, Add), \
@ -198,12 +198,13 @@
(2, LessThan),\ (2, LessThan),\
(3, Epsilon),\ (3, Epsilon),\
(4, GreaterThanOrEqual),\ (4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\ (5, MatchCondition) ,\
(6, NotEqualTo),\ (6, NotEqualTo),\
(7, And),\ (7, And),\
(8, Or),\ (8, Or),\
(9, Xor) ,\ (9, Xor) ,\
(10, Not) (10, Not) ,\
(11, LessThanOrEqual)
#define SCALAR_OPS \ #define SCALAR_OPS \
(0, Add),\ (0, Add),\
@ -341,12 +342,13 @@
(2, LessThan),\ (2, LessThan),\
(3, Epsilon),\ (3, Epsilon),\
(4, GreaterThanOrEqual),\ (4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\ (5, MatchCondition) ,\
(6, NotEqualTo),\ (6, NotEqualTo),\
(7, And),\ (7, And),\
(8, Or),\ (8, Or),\
(9, Xor) ,\ (9, Xor) ,\
(10, Not) (10, Not) ,\
(11, LessThanOrEqual)
#define PAIRWISE_TRANSFORM_OPS \ #define PAIRWISE_TRANSFORM_OPS \
(0, Add),\ (0, Add),\

View File

@ -2302,54 +2302,66 @@ namespace simdOps {
return old + opOutput; return old + opOutput;
} }
// this op return 1.0 if condition met, 0.0 otherwise op_def static Z op(X d1, X compare, X eps, int mode) {
op_def static Z op(X d1, X *extraParams) { switch (mode) {
X compare = extraParams[0]; case 0: // equals
X eps = extraParams[1]; return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0;
case 1: // not equals
auto mode = static_cast<int>(extraParams[2]); return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0;
//printf("value: %f; comp: %f; eps: %f; mode: %i;\n", (float) d1, (float) compare, (float) eps, mode); case 2: // less_than
return d1 < compare ? 1 : 0;
switch (mode) { case 3: // greater_than
case 0: // equals return d1 > compare ? 1 : 0;
return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0; case 4: // less_or_equals_than
case 1: // not equals return d1 <= compare ? 1 : 0;
return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0; case 5: // greater_or_equals_than
case 2: // less_than return d1 >= compare ? 1 : 0;
return d1 < compare ? 1 : 0; case 6: // abs_less_than
case 3: // greater_than return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0;
return d1 > compare ? 1 : 0; case 7: // abs_greater_than
case 4: // less_or_equals_than return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0;
return d1 <= compare ? 1 : 0; case 8: // is inf
case 5: // greater_or_equals_than return nd4j::math::nd4j_isinf(d1) ? 1 : 0;
return d1 >= compare ? 1 : 0; case 9: // is nan
case 6: // abs_less_than return nd4j::math::nd4j_isnan(d1) ? 1 : 0;
return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0; case 10:
case 7: // abs_greater_than return (d1 == compare) ? 1 : 0;
return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0; case 11:
case 8: // is inf return (d1 != compare) ? 1 : 0;
return nd4j::math::nd4j_isinf(d1) ? 1 : 0; case 12: // abs_greater_or_equals_than
case 9: // is nan return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
return nd4j::math::nd4j_isnan(d1) ? 1 : 0; case 13: // abs_less_or_equals_than
case 10: return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
return (d1 == compare) ? 1 : 0;
case 11:
return (d1 != compare) ? 1 : 0;
case 12: // abs_greater_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
case 13: // abs_less_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
case 14: case 14:
// isFinite // isFinite
return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0; return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0;
case 15: case 15:
// isInfinite // isInfinite
return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0; return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0;
default: default:
printf("Undefined match condition: [%i]\n", mode); printf("Undefined match condition: [%i]\n", mode);
} }
return d1; return d1;
}
// this op returns 1.0 if the condition is met, 0.0 otherwise
op_def static Z op(X d1, X compare, X *extraParams) {
X eps = extraParams[1];
auto mode = static_cast<int>(extraParams[0]);
return op(d1, compare, eps, mode);
}
// this op returns 1.0 if the condition is met, 0.0 otherwise
op_def static Z op(X d1, X *extraParams) {
X compare = extraParams[0];
X eps = extraParams[1];
auto mode = static_cast<int>(extraParams[2]);
return op(d1, compare, eps, mode);
} }
op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) { op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) {

View File

@ -1342,6 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) {
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
(int*)devicePtrs[0], dimensions.size(), (int*)devicePtrs[0], dimensions.size(),
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
nullptr, nullptr); nullptr, nullptr);
@ -1400,6 +1401,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) {
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
(int*)devicePtrs[0], dimensions.size(), (int*)devicePtrs[0], dimensions.size(),
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
nullptr, nullptr); nullptr, nullptr);

View File

@ -674,6 +674,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
nullptr, 0, nullptr, 0,
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); tadPackY.platformShapeInfo(), tadPackY.platformOffsets());

View File

@ -202,6 +202,7 @@ printf("Unsupported for cuda now.\n");
nullptr, nullptr, nullptr, nullptr,
exp.buffer(), exp.shapeInfo(), exp.buffer(), exp.shapeInfo(),
nullptr, nullptr, nullptr, nullptr,
nullptr,
dimension.buffer(), dimension.shapeInfo(), dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr); nullptr, nullptr);
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0)); ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));

View File

@ -417,7 +417,7 @@ public class LegacyOpMapper {
case 4: case 4:
return ScalarGreaterThanOrEqual.class; return ScalarGreaterThanOrEqual.class;
case 5: case 5:
return ScalarLessThanOrEqual.class; return MatchCondition.class;
case 6: case 6:
return ScalarNotEquals.class; return ScalarNotEquals.class;
case 7: case 7:
@ -428,6 +428,8 @@ public class LegacyOpMapper {
return ScalarXor.class; return ScalarXor.class;
case 10: case 10:
return ScalarNot.class; return ScalarNot.class;
case 11:
return ScalarLessThanOrEqual.class;
default: default:
throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum); throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum);
} }

View File

@ -1864,7 +1864,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override @Override
public INDArray match(INDArray comp, Condition condition) { public INDArray match(INDArray comp, Condition condition) {
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition)); // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition
Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal");
Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal");
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition));
} }
@Override @Override

View File

@ -64,7 +64,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp {
@Override @Override
public int opNum() { public int opNum() {
return 5; return 11;
} }
@Override @Override

View File

@ -55,7 +55,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
@Override @Override
public int opNum() { public int opNum() {
return 5; return 11;
} }
@Override @Override

View File

@ -53,6 +53,11 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
public MatchConditionTransform() {} public MatchConditionTransform() {}
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray y, @NonNull INDArray z, @NonNull Condition condition) {
this(x, z, Nd4j.EPS_THRESHOLD, condition);
this.y = y;
}
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) { public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) {
this(x, z, Nd4j.EPS_THRESHOLD, condition); this(x, z, Nd4j.EPS_THRESHOLD, condition);
} }

View File

@ -25,71 +25,283 @@ import org.nd4j.linalg.factory.Nd4j;
*/ */
public class Conditions { public class Conditions {
/**
* This method will create Condition that checks if value is infinite
* @return
*/
public static Condition isInfinite() { public static Condition isInfinite() {
return new IsInfinite(); return new IsInfinite();
} }
/**
* This method will create Condition that checks if value is NaN
* @return
*/
public static Condition isNan() { public static Condition isNan() {
return new IsNaN(); return new IsNaN();
} }
/**
* This method will create Condition that checks if value is finite
* @return
*/
public static Condition isFinite() { public static Condition isFinite() {
return new IsFinite(); return new IsFinite();
} }
/**
* This method will create Condition that checks if value is NOT finite
* @return
*/
public static Condition notFinite() { public static Condition notFinite() {
return new NotFinite(); return new NotFinite();
} }
/**
* This method will create a Condition that checks whether two values are not equal with respect to eps
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition epsNotEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return epsNotEquals(0.0);
}
/**
* This method will create a Condition that checks whether two values are not equal with respect to eps
*
* @return
*/
public static Condition epsNotEquals(Number value) { public static Condition epsNotEquals(Number value) {
return new EpsilonNotEquals(value); return new EpsilonNotEquals(value);
} }
/**
* This method will create a Condition that checks whether two values are equal with respect to eps
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition epsEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return epsEquals(0.0);
}
/**
* This method will create a Condition that checks whether two values are equal with respect to eps
*
* @return
*/
public static Condition epsEquals(Number value) { public static Condition epsEquals(Number value) {
return epsEquals(value, Nd4j.EPS_THRESHOLD); return epsEquals(value, Nd4j.EPS_THRESHOLD);
} }
/**
* This method will create a Condition that checks whether two values are equal with respect to eps
*
* @return
*/
public static Condition epsEquals(Number value, Number epsilon) { public static Condition epsEquals(Number value, Number epsilon) {
return new EpsilonEquals(value, epsilon.doubleValue()); return new EpsilonEquals(value, epsilon.doubleValue());
} }
/**
* This method will create a Condition that checks whether two values are equal
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition equals() {
// in case of pairwise MatchCondition we don't really care about number here
return equals(0.0);
}
/**
* This method will create a Condition that checks whether two values are equal
*
* @return
*/
public static Condition equals(Number value) { public static Condition equals(Number value) {
return new EqualsCondition(value); return new EqualsCondition(value);
} }
/**
* This method will create a Condition that checks whether two values are not equal
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition notEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return notEquals(0.0);
}
/**
* This method will create a Condition that checks whether two values are not equal
*
* @return
*/
public static Condition notEquals(Number value) { public static Condition notEquals(Number value) {
return new NotEqualsCondition(value); return new NotEqualsCondition(value);
} }
/**
* This method will create a Condition that checks whether value X is greater than value Y
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition greaterThan() {
// in case of pairwise MatchCondition we don't really care about number here
return greaterThan(0.0);
}
/**
* This method will create a Condition that checks whether value X is greater than value Y
*
* @return
*/
public static Condition greaterThan(Number value) { public static Condition greaterThan(Number value) {
return new GreaterThan(value); return new GreaterThan(value);
} }
/**
* Creates a Condition that checks whether value X is less than value Y.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code lessThan(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition lessThan() {
// in case of pairwise MatchCondition we don't really care about number here
return lessThan(0.0);
}
/**
* Creates a Condition that checks whether value X is less than value Y.
*
* @param value number handed to the LessThan constructor
* @return a new LessThan condition built from {@code value}
*/
public static Condition lessThan(Number value) { public static Condition lessThan(Number value) {
return new LessThan(value); return new LessThan(value);
} }
/**
* Creates a Condition that checks whether value X is less than or equal to value Y.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code lessThanOrEqual(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition lessThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return lessThanOrEqual(0.0);
}
/**
* Creates a Condition that checks whether value X is less than or equal to value Y.
*
* @param value number handed to the LessThanOrEqual constructor
* @return a new LessThanOrEqual condition built from {@code value}
*/
public static Condition lessThanOrEqual(Number value) { public static Condition lessThanOrEqual(Number value) {
return new LessThanOrEqual(value); return new LessThanOrEqual(value);
} }
/**
* Creates a Condition that checks whether value X is greater than or equal to value Y.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code greaterThanOrEqual(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition greaterThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return greaterThanOrEqual(0.0);
}
/**
* Creates a Condition that checks whether value X is greater than or equal to value Y.
*
* @param value number handed to the GreaterThanOrEqual constructor
* @return a new GreaterThanOrEqual condition built from {@code value}
*/
public static Condition greaterThanOrEqual(Number value) { public static Condition greaterThanOrEqual(Number value) {
return new GreaterThanOrEqual(value); return new GreaterThanOrEqual(value);
} }
/**
* Creates a Condition that checks whether value X is greater than or equal to value Y in absolute values.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code absGreaterThanOrEqual(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition absGreaterThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return absGreaterThanOrEqual(0.0);
}
/**
* Creates a Condition that checks whether value X is greater than or equal to value Y in absolute values.
*
* @param value number handed to the AbsValueGreaterOrEqualsThan constructor
* @return a new AbsValueGreaterOrEqualsThan condition built from {@code value}
*/
public static Condition absGreaterThanOrEqual(Number value) { public static Condition absGreaterThanOrEqual(Number value) {
return new AbsValueGreaterOrEqualsThan(value); return new AbsValueGreaterOrEqualsThan(value);
} }
/**
* Creates a Condition that checks whether value X is less than or equal to value Y in absolute values.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code absLessThanOrEqual(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition absLessThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return absLessThanOrEqual(0.0);
}
/**
* Creates a Condition that checks whether value X is less than or equal to value Y in absolute values.
*
* @param value number handed to the AbsValueLessOrEqualsThan constructor
* @return a new AbsValueLessOrEqualsThan condition built from {@code value}
*/
public static Condition absLessThanOrEqual(Number value) { public static Condition absLessThanOrEqual(Number value) {
return new AbsValueLessOrEqualsThan(value); return new AbsValueLessOrEqualsThan(value);
} }
/**
* Creates a Condition that checks whether value X is greater than value Y in absolute values.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code absGreaterThan(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition absGreaterThan() {
// in case of pairwise MatchCondition we don't really care about number here
return absGreaterThan(0.0);
}
/**
* Creates a Condition that checks whether value X is greater than value Y in absolute values.
*
* @param value number handed to the AbsValueGreaterThan constructor
* @return a new AbsValueGreaterThan condition built from {@code value}
*/
public static Condition absGreaterThan(Number value) { public static Condition absGreaterThan(Number value) {
return new AbsValueGreaterThan(value); return new AbsValueGreaterThan(value);
} }
/**
* Creates a Condition that checks whether value X is less than value Y in absolute values.
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return the Condition produced by {@code absLessThan(0.0)}; the 0.0 argument is only a placeholder
*/
public static Condition absLessThan() {
// in case of pairwise MatchCondition we don't really care about number here
return absLessThan(0.0);
}
/**
* Creates a Condition that checks whether value X is less than value Y in absolute values.
*
* @param value number handed to the AbsValueLessThan constructor
* @return a new AbsValueLessThan condition built from {@code value}
*/
public static Condition absLessThan(Number value) { public static Condition absLessThan(Number value) {
return new AbsValueLessThan(value); return new AbsValueLessThan(value);
} }

View File

@ -129,6 +129,7 @@ public interface NativeOps {
@Cast("Nd4jLong *") LongPointer resultShapeInfo, @Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult, Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo, @Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape, Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape); Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);

View File

@ -198,7 +198,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo, null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo,
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context), null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
null, null, null,
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(op.dimensions(), context), AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
null); null);
@ -805,7 +805,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo, null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo,
null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
null, null, null,
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(op.dimensions(), context), AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
null); null);

View File

@ -916,6 +916,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
public native void execBroadcastBool( public native void execBroadcastBool(
@ -927,6 +928,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
public native void execBroadcastBool( public native void execBroadcastBool(
@ -938,6 +940,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);

View File

@ -967,6 +967,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
null, null, null, null,
op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(), op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(),
null, null, null, null,
null,
op.dimensions().data().addressPointer(), op.dimensions().data().addressPointer(),
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
null, null,

View File

@ -916,6 +916,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
public native void execBroadcastBool( public native void execBroadcastBool(
@ -927,6 +928,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
public native void execBroadcastBool( public native void execBroadcastBool(
@ -938,6 +940,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo, Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo, Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo, Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape, Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape); Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);

View File

@ -46,6 +46,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList; import java.util.ArrayList;
@ -1087,6 +1088,24 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
} }
@Test
public void testMatch_1() {
    // Pairwise match of an all-ones 3x3 array against a reshaped 3x3 linspace array.
    final INDArray x = Nd4j.ones(DataType.FLOAT, 3, 3);
    final INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
    System.out.println("Y:\n" + y);

    // Only the first element of the last row is expected to satisfy the condition.
    final INDArray expected = Nd4j.createFromArray(new boolean[][]{
            {false, false, false},
            {false, false, false},
            {true, false, false}
    });

    final INDArray actual = x.match(y, Conditions.equals(0.0));
    assertEquals(expected, actual);
}
@Test @Test
public void testCreateOp_1() { public void testCreateOp_1() {