Few fixes (#66)

* skip legacy transforms execution in case of empty input arrays

Signed-off-by: raver119 <raver119@gmail.com>

* - BroadcastBool ops now accept extraParams to make MatchCondition possible
- TrueBroadcastHelper now uses samediff::threads

Signed-off-by: raver119 <raver119@gmail.com>

* java side

Signed-off-by: raver119 <raver119@gmail.com>

* trigger jenkins

Signed-off-by: raver119 <raver119@gmail.com>

* update LessThanOrEqual opNum mapping

Signed-off-by: raver119 <raver119@gmail.com>

* update LessThanOrEqual opNum mapping

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-21 15:43:03 +03:00 committed by GitHub
parent 83cb0d9329
commit 064a56ccf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 607 additions and 203 deletions

View File

@ -2772,9 +2772,9 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
// TODO: eventually we want separate tads here
NDArray::prepareSpecialUse({result}, {this, other});
if(max == this)
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
else
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
registerSpecialUse({result}, {this, other});
}

View File

@ -284,6 +284,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
@ -296,6 +297,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *result, Nd4jLong *resultShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);

View File

@ -179,6 +179,7 @@ ND4J_EXPORT void execBroadcastBool(
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape);

View File

@ -156,6 +156,9 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else
@ -187,7 +190,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
@ -219,6 +223,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -228,8 +233,11 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
};
auto xLen = shape::length(hXShapeInfo);
@ -247,22 +255,24 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
if (yType != xType || nd4j::DataType::BOOL != zType)
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
auto func = PRAGMA_THREADS_FOR {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
};
auto xLen = shape::length(hXShapeInfo);
@ -292,6 +302,9 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);
@ -321,12 +334,13 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
@ -367,11 +381,13 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
#else
@ -403,6 +419,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);
@ -429,11 +448,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
@ -837,11 +858,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dScalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
#else
@ -872,11 +895,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else
@ -904,12 +929,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
return;
if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
@ -939,11 +965,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
@ -969,12 +997,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
void *dScalar, Nd4jLong *dSscalarShapeInfo,
void *extraParams, bool allowParallelism) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
return;
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
@ -1004,11 +1033,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || xType != zType)
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
@ -1126,6 +1157,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
};
@ -1145,6 +1179,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
};
@ -1164,6 +1201,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
};
@ -1183,6 +1223,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
};
@ -1202,6 +1245,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
auto func = PRAGMA_THREADS_DO {
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
};

View File

@ -231,8 +231,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape) {
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape) {
try {
auto dimension = reinterpret_cast<int *>(hDimension);
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
@ -259,6 +260,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
dYShapeInfo,
hZ, hZShapeInfo,
dZ, dZShapeInfo,
extraParams,
dimension,
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ,
hTADOffsetsZ);

View File

@ -101,6 +101,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (xType != zType && yType != zType)
throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type");
if (lc == nullptr)
@ -139,6 +142,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
@ -172,6 +178,9 @@ void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType))
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
@ -223,6 +232,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -233,6 +243,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
@ -244,7 +257,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
dim3 launchDims(256, 256, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@ -260,6 +273,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
@ -269,18 +283,18 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isB(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
if (yType != xType)
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type");
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3BI opNum:[%i]\n", opNum);
dim3 launchDims(256, 256, 1024);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
// TODO: remove after the release
auto res = cudaStreamSynchronize(*stream);
@ -308,15 +322,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
if (yType != xType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3B opNum:[%i]\n", opNum);
dim3 launchDims(256, 256, 1024);
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)
@ -344,6 +358,9 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
if (!DataTypeUtils::isZ(zType))
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
@ -394,8 +411,8 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3 opNum:[%i]\n", opNum);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
dim3 launchDims(256, 256, 1024);
@ -429,8 +446,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
printf("F3I opNum:[%i]\n", opNum);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
return;
dim3 launchDims(256, 256, 1024);
@ -832,16 +849,21 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
if (xType != zType)
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type");
if (shape::isEmpty(hXShapeInfo)) {
return;
}
if (xType != zType) {
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type");
}
dim3 launchDims(512, 512, 16384);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES);
// TODO: remove after the release
@ -861,16 +883,21 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
if (!DataTypeUtils::isB(zType))
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");
if (shape::isEmpty(hXShapeInfo)) {
return;
}
if (!DataTypeUtils::isB(zType)) {
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");
}
dim3 launchDims(512, 512, 16384);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES);
// TODO: remove after the release
@ -896,6 +923,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
dim3 launchDims(512, 512, 2048);
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
@ -917,16 +947,21 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
auto stream = lc->getCudaStream();
dim3 launchDims(512, 512, 16384);
auto xRank = shape::rank(hXShapeInfo);
auto zRank = shape::rank(hZShapeInfo);
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
if (xType != zType || !DataTypeUtils::isR(xType))
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
if (shape::isEmpty(hXShapeInfo)) {
return;
}
if (xType != zType || !DataTypeUtils::isR(xType)) {
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
}
dim3 launchDims(512, 512, 16384);
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
// TODO: remove after the release
@ -953,6 +988,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
auto xType = ArrayOptions::dataType(hXShapeInfo);
auto zType = ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo))
return;
if (!DataTypeUtils::isR(zType))
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
@ -1175,6 +1213,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType )
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
@ -1211,6 +1252,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType )
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
@ -1244,6 +1288,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
@ -1280,6 +1327,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
if (xType != yType || zType != xType)
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
@ -1313,6 +1363,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
#ifdef __ND4J_EXPERIMENTAL__
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
@ -1346,6 +1399,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
return;
dim3 launchDims(256, 256, 16384);
#ifdef __ND4J_EXPERIMENTAL__

View File

@ -294,6 +294,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
void *dY, Nd4jLong *dYShapeInfo,
void *hZ, Nd4jLong *hZShapeInfo,
void *dZ, Nd4jLong *dZShapeInfo,
void *extraParams,
void *hDimension, Nd4jLong *hDimensionShape,
void *dDimension, Nd4jLong *dDimensionShape) {
try {
@ -313,7 +314,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY,
dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension,
dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension,
dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ,
tadOffsetsZ);
} catch (std::exception &e) {

View File

@ -20,6 +20,7 @@
#include <TrueBroadcastHelper.h>
#include <ops/ops.h>
#include <execution/Threads.h>
using namespace simdOps;
@ -47,36 +48,39 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
@ -103,38 +107,40 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
@ -163,36 +169,39 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
for (Nd4jLong i = 0; i < zLen; ++i) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
shape::index2coords(i, zShapeInfo, zCoords.data());
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
if(iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>

View File

@ -86,7 +86,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
}
@ -172,7 +172,7 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
}

View File

@ -65,13 +65,14 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template<typename OpType>
static __device__ void transformInverseCuda(
@ -81,13 +82,14 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
template <typename OpClass>
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
#else
@ -98,6 +100,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
@ -114,6 +117,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
@ -141,6 +145,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,
@ -157,6 +162,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *result,
Nd4jLong *resultShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *tadShapeInfo,

View File

@ -39,6 +39,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
@ -53,6 +54,7 @@ namespace functions {
yShapeInfo,
z,
zShapeInfo,
extraParams,
dimension,
dimensionLength,
xTadShapeInfo,
@ -69,6 +71,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
@ -83,6 +86,7 @@ namespace functions {
yShapeInfo,
z,
zShapeInfo,
extraParams,
dimension,
dimensionLength,
xTadShapeInfo,
@ -99,6 +103,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension,
int dimensionLength,
Nd4jLong *xTadShapeInfo,
@ -111,6 +116,7 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after
//moving all dimensions (in sorted order)
@ -155,7 +161,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(oX[f], y[f]);
oZ[f] = OpType::op(oX[f], y[f], extraParams);
}
}
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
@ -165,7 +171,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams);
};
}
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
@ -179,7 +185,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
oZ[offset] = OpType::op(oX[offset], y[offset]);
oZ[offset] = OpType::op(oX[offset], y[offset], extraParams);
}
};
}
@ -197,7 +203,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams);
}
};
}
@ -215,7 +221,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams);
}
};
@ -234,7 +240,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams);
}
};
}
@ -255,7 +261,7 @@ namespace functions {
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams);
}
};
}
@ -270,6 +276,7 @@ namespace functions {
Nd4jLong *yShapeInfo,
void *vz,
Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension,
int dimensionLength,
Nd4jLong *yTadShapeInfo,
@ -282,6 +289,7 @@ namespace functions {
auto x = reinterpret_cast<X *>(vx);
auto y = reinterpret_cast<X *>(vy);
auto z = reinterpret_cast<Z *>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after
//moving all dimensions (in sorted order)
@ -326,7 +334,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < tadLength; f++)
oZ[f] = OpType::op(x[f], oY[f]);
oZ[f] = OpType::op(x[f], oY[f], extraParams);
}
}
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
@ -336,7 +344,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (uint f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams);
}
}
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
@ -351,7 +359,7 @@ namespace functions {
PRAGMA_OMP_SIMD
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
oZ[offset] = OpType::op(x[offset], oY[offset]);
oZ[offset] = OpType::op(x[offset], oY[offset], extraParams);
}
}
}
@ -370,7 +378,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams);
}
}
}
@ -389,7 +397,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams);
}
}
}
@ -408,7 +416,7 @@ namespace functions {
for (int f = 0; f < tadLength; f++) {
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams);
}
}
}
@ -430,7 +438,7 @@ namespace functions {
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams);
}
}
}

View File

@ -40,10 +40,11 @@ static __global__ void broadcastBoolSimple(
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
//////////////////////////////////////////////////////////////////////////
@ -55,10 +56,11 @@ static __global__ void broadcastBoolInverseSimple(
Nd4jLong *yShapeInfo,
void *z,
Nd4jLong *zShapeInfo,
void *extraParams,
int *dimension,
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
}
namespace functions {
@ -66,15 +68,15 @@ namespace functions {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
}
@ -82,15 +84,15 @@ namespace functions {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
template <typename OpClass>
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
DEBUG_KERNEL(stream, opNum);
}
@ -102,6 +104,7 @@ namespace functions {
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@ -113,6 +116,7 @@ namespace functions {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after
//moving all dimensions (in sorted order)
@ -140,7 +144,7 @@ namespace functions {
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams);
}
else {
// it is expected that x and z tads and y array all have the same length
@ -149,7 +153,7 @@ namespace functions {
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams);
}
}
}
@ -162,6 +166,7 @@ namespace functions {
void *vx, Nd4jLong *xShapeInfo,
void *vy, Nd4jLong *yShapeInfo,
void *vz, Nd4jLong *zShapeInfo,
void *vextraParams,
int *dimension, int dimensionLength,
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
@ -173,6 +178,7 @@ namespace functions {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto extraParams = reinterpret_cast<X*>(vextraParams);
//decompose in to several sub tads after
//moving all dimensions (in sorted order)
@ -208,7 +214,7 @@ namespace functions {
if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams);
}
else {
// it is expected that x and z tads and y array all have the same length
@ -217,7 +223,7 @@ namespace functions {
auto yOffset = shape::getIndexOffset(i, yShapeInfo);
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams);
}
}
}

View File

@ -45,13 +45,13 @@
(2, LessThan),\
(3, Epsilon),\
(4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\
(5, MatchCondition) ,\
(6, NotEqualTo),\
(7, And),\
(8, Or),\
(9, Xor) ,\
(10, Not)
(10, Not) ,\
(11, LessThanOrEqual)
#define BROADCAST_OPS \
(0, Add), \
@ -198,12 +198,13 @@
(2, LessThan),\
(3, Epsilon),\
(4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\
(5, MatchCondition) ,\
(6, NotEqualTo),\
(7, And),\
(8, Or),\
(9, Xor) ,\
(10, Not)
(10, Not) ,\
(11, LessThanOrEqual)
#define SCALAR_OPS \
(0, Add),\
@ -341,12 +342,13 @@
(2, LessThan),\
(3, Epsilon),\
(4, GreaterThanOrEqual),\
(5, LessThanOrEqual),\
(5, MatchCondition) ,\
(6, NotEqualTo),\
(7, And),\
(8, Or),\
(9, Xor) ,\
(10, Not)
(10, Not) ,\
(11, LessThanOrEqual)
#define PAIRWISE_TRANSFORM_OPS \
(0, Add),\

View File

@ -2302,54 +2302,66 @@ namespace simdOps {
return old + opOutput;
}
// this op return 1.0 if condition met, 0.0 otherwise
op_def static Z op(X d1, X *extraParams) {
X compare = extraParams[0];
X eps = extraParams[1];
auto mode = static_cast<int>(extraParams[2]);
//printf("value: %f; comp: %f; eps: %f; mode: %i;\n", (float) d1, (float) compare, (float) eps, mode);
switch (mode) {
case 0: // equals
return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0;
case 1: // not equals
return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0;
case 2: // less_than
return d1 < compare ? 1 : 0;
case 3: // greater_than
return d1 > compare ? 1 : 0;
case 4: // less_or_equals_than
return d1 <= compare ? 1 : 0;
case 5: // greater_or_equals_than
return d1 >= compare ? 1 : 0;
case 6: // abs_less_than
return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0;
case 7: // abs_greater_than
return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0;
case 8: // is inf
return nd4j::math::nd4j_isinf(d1) ? 1 : 0;
case 9: // is nan
return nd4j::math::nd4j_isnan(d1) ? 1 : 0;
case 10:
return (d1 == compare) ? 1 : 0;
case 11:
return (d1 != compare) ? 1 : 0;
case 12: // abs_greater_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
case 13: // abs_less_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
op_def static Z op(X d1, X compare, X eps, int mode) {
switch (mode) {
case 0: // equals
return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0;
case 1: // not equals
return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0;
case 2: // less_than
return d1 < compare ? 1 : 0;
case 3: // greater_than
return d1 > compare ? 1 : 0;
case 4: // less_or_equals_than
return d1 <= compare ? 1 : 0;
case 5: // greater_or_equals_than
return d1 >= compare ? 1 : 0;
case 6: // abs_less_than
return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0;
case 7: // abs_greater_than
return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0;
case 8: // is inf
return nd4j::math::nd4j_isinf(d1) ? 1 : 0;
case 9: // is nan
return nd4j::math::nd4j_isnan(d1) ? 1 : 0;
case 10:
return (d1 == compare) ? 1 : 0;
case 11:
return (d1 != compare) ? 1 : 0;
case 12: // abs_greater_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
case 13: // abs_less_or_equals_than
return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
case 14:
// isFinite
return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0;
case 15:
// isInfinite
return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0;
default:
printf("Undefined match condition: [%i]\n", mode);
}
default:
printf("Undefined match condition: [%i]\n", mode);
}
return d1;
}
// this op return 1.0 if condition met, 0.0 otherwise
op_def static Z op(X d1, X compare, X *extraParams) {
X eps = extraParams[1];
auto mode = static_cast<int>(extraParams[0]);
return op(d1, compare, eps, mode);
}
// this op return 1.0 if condition met, 0.0 otherwise
op_def static Z op(X d1, X *extraParams) {
X compare = extraParams[0];
X eps = extraParams[1];
auto mode = static_cast<int>(extraParams[2]);
return op(d1, compare, eps, mode);
}
op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) {

View File

@ -1342,6 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) {
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
(int*)devicePtrs[0], dimensions.size(),
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
nullptr, nullptr);
@ -1400,6 +1401,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) {
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
(int*)devicePtrs[0], dimensions.size(),
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
nullptr, nullptr);

View File

@ -674,6 +674,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
nullptr,
nullptr, 0,
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());

View File

@ -202,6 +202,7 @@ printf("Unsupported for cuda now.\n");
nullptr, nullptr,
exp.buffer(), exp.shapeInfo(),
nullptr, nullptr,
nullptr,
dimension.buffer(), dimension.shapeInfo(),
nullptr, nullptr);
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));

View File

@ -417,7 +417,7 @@ public class LegacyOpMapper {
case 4:
return ScalarGreaterThanOrEqual.class;
case 5:
return ScalarLessThanOrEqual.class;
return MatchCondition.class;
case 6:
return ScalarNotEquals.class;
case 7:
@ -428,6 +428,8 @@ public class LegacyOpMapper {
return ScalarXor.class;
case 10:
return ScalarNot.class;
case 11:
return ScalarLessThanOrEqual.class;
default:
throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum);
}

View File

@ -1864,7 +1864,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public INDArray match(INDArray comp, Condition condition) {
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition));
// TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition
Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal");
Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal");
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition));
}
@Override

View File

@ -64,7 +64,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp {
@Override
public int opNum() {
return 5;
return 11;
}
@Override

View File

@ -55,7 +55,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
@Override
public int opNum() {
return 5;
return 11;
}
@Override

View File

@ -53,6 +53,11 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
public MatchConditionTransform() {}
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray y, @NonNull INDArray z, @NonNull Condition condition) {
this(x, z, Nd4j.EPS_THRESHOLD, condition);
this.y = y;
}
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) {
this(x, z, Nd4j.EPS_THRESHOLD, condition);
}

View File

@ -25,71 +25,283 @@ import org.nd4j.linalg.factory.Nd4j;
*/
public class Conditions {
/**
* This method will create Condition that checks if value is infinite
* @return
*/
public static Condition isInfinite() {
return new IsInfinite();
}
/**
* This method will create Condition that checks if value is NaN
* @return
*/
public static Condition isNan() {
return new IsNaN();
}
/**
* This method will create Condition that checks if value is finite
* @return
*/
public static Condition isFinite() {
return new IsFinite();
}
/**
* This method will create Condition that checks if value is NOT finite
* @return
*/
public static Condition notFinite() {
return new NotFinite();
}
/**
* This method will create Condition that checks if value is two values are not equal wrt eps
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition epsNotEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return epsNotEquals(0.0);
}
/**
* This method will create Condition that checks if value is two values are not equal wrt eps
*
* @return
*/
public static Condition epsNotEquals(Number value) {
return new EpsilonNotEquals(value);
}
/**
* This method will create Condition that checks if value is two values are equal wrt eps
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition epsEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return epsEquals(0.0);
}
/**
* This method will create Condition that checks if value is two values are equal wrt eps
*
* @return
*/
public static Condition epsEquals(Number value) {
return epsEquals(value, Nd4j.EPS_THRESHOLD);
}
/**
* This method will create Condition that checks if value is two values are equal wrt eps
*
* @return
*/
public static Condition epsEquals(Number value, Number epsilon) {
return new EpsilonEquals(value, epsilon.doubleValue());
}
/**
* This method will create Condition that checks if value is two values are equal
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition equals() {
// in case of pairwise MatchCondition we don't really care about number here
return equals(0.0);
}
/**
* This method will create Condition that checks if value is two values are equal
*
* @return
*/
public static Condition equals(Number value) {
return new EqualsCondition(value);
}
/**
* This method will create Condition that checks if value is two values are not equal
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition notEquals() {
// in case of pairwise MatchCondition we don't really care about number here
return notEquals(0.0);
}
/**
* This method will create Condition that checks if value is two values are not equal
*
* @return
*/
public static Condition notEquals(Number value) {
return new NotEqualsCondition(value);
}
/**
* This method will create Condition that checks if value is value X is greater than value Y
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition greaterThan() {
// in case of pairwise MatchCondition we don't really care about number here
return greaterThan(0.0);
}
/**
* This method will create Condition that checks if value is value X is greater than value Y
*
* @return
*/
public static Condition greaterThan(Number value) {
return new GreaterThan(value);
}
/**
* This method will create Condition that checks if value is value X is less than value Y
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition lessThan() {
// in case of pairwise MatchCondition we don't really care about number here
return lessThan(0.0);
}
/**
* This method will create Condition that checks if value is value X is less than value Y
*
* @return
*/
public static Condition lessThan(Number value) {
return new LessThan(value);
}
/**
* This method will create Condition that checks if value is value X is less than or equal to value Y
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition lessThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return lessThanOrEqual(0.0);
}
/**
* This method will create Condition that checks if value is value X is less than or equal to value Y
*
* @return
*/
public static Condition lessThanOrEqual(Number value) {
return new LessThanOrEqual(value);
}
/**
* This method will create Condition that checks if value is value X is greater than or equal to value Y
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition greaterThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return greaterThanOrEqual(0.0);
}
/**
* This method will create Condition that checks if value is value X is greater than or equal to value Y
*
* @return
*/
public static Condition greaterThanOrEqual(Number value) {
return new GreaterThanOrEqual(value);
}
/**
* This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition absGreaterThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return absGreaterThanOrEqual(0.0);
}
/**
* This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values
*
* @return
*/
public static Condition absGreaterThanOrEqual(Number value) {
return new AbsValueGreaterOrEqualsThan(value);
}
/**
* This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition absLessThanOrEqual() {
// in case of pairwise MatchCondition we don't really care about number here
return absLessThanOrEqual(0.0);
}
/**
* This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values
*
* @return
*/
public static Condition absLessThanOrEqual(Number value) {
return new AbsValueLessOrEqualsThan(value);
}
/**
* This method will create Condition that checks if value is value X is greater than value Y in absolute values
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition absGreaterThan() {
// in case of pairwise MatchCondition we don't really care about number here
return absGreaterThan(0.0);
}
/**
* This method will create Condition that checks if value is value X is greater than value Y in absolute values
*
* @return
*/
public static Condition absGreaterThan(Number value) {
return new AbsValueGreaterThan(value);
}
/**
* This method will create Condition that checks if value is value X is less than value Y in absolute values
*
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
* @return
*/
public static Condition absLessThan() {
// in case of pairwise MatchCondition we don't really care about number here
return absLessThan(0.0);
}
/**
* This method will create Condition that checks if value is value X is less than value Y in absolute values
*
* @return
*/
public static Condition absLessThan(Number value) {
return new AbsValueLessThan(value);
}

View File

@ -129,6 +129,7 @@ public interface NativeOps {
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
Pointer dresult,
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);

View File

@ -198,7 +198,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo,
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
null,
null, null,
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
null);
@ -805,7 +805,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo,
null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
null,
null, null,
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
null);

View File

@ -916,6 +916,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
public native void execBroadcastBool(
@ -927,6 +928,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
public native void execBroadcastBool(
@ -938,6 +940,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);

View File

@ -967,6 +967,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
null, null,
op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(),
null, null,
null,
op.dimensions().data().addressPointer(),
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
null,

View File

@ -916,6 +916,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
public native void execBroadcastBool(
@ -927,6 +928,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
public native void execBroadcastBool(
@ -938,6 +940,7 @@ public native void execBroadcastBool(
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
Pointer extraParams,
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);

View File

@ -46,6 +46,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.util.ArrayList;
@ -1087,6 +1088,24 @@ public class CustomOpsTests extends BaseNd4jTest {
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
}
@Test
public void testMatch_1() {
INDArray x = Nd4j.ones(DataType.FLOAT, 3,3);
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
val c = Conditions.equals(0.0);
System.out.println("Y:\n" + y);
INDArray z = x.match(y, c);
INDArray exp = Nd4j.createFromArray(new boolean[][]{
{false, false, false},
{false, false, false},
{true, false, false}
});
assertEquals(exp, z);
}
@Test
public void testCreateOp_1() {