Few fixes (#66)
* skip legacy transforms execution in case of empty input arrays Signed-off-by: raver119 <raver119@gmail.com> * - BroadcastBool ops now accept extraParams to make MatchCondition possible - TrueBroadcastHelper now uses samediff::threads Signed-off-by: raver119 <raver119@gmail.com> * java side Signed-off-by: raver119 <raver119@gmail.com> * trigger jenkins Signed-off-by: raver119 <raver119@gmail.com> * update LessThanOrEqual opNum mapping Signed-off-by: raver119 <raver119@gmail.com> * update LessThanOrEqual opNum mapping Signed-off-by: raver119 <raver119@gmail.com>master
parent
83cb0d9329
commit
064a56ccf1
|
@ -2772,9 +2772,9 @@ void NDArray::applyBroadcast(nd4j::broadcast::BoolOps op, const std::vector<int>
|
||||||
// TODO: eventually we want separate tads here
|
// TODO: eventually we want separate tads here
|
||||||
NDArray::prepareSpecialUse({result}, {this, other});
|
NDArray::prepareSpecialUse({result}, {this, other});
|
||||||
if(max == this)
|
if(max == this)
|
||||||
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
|
NativeOpExecutioner::execBroadcastBool( getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
|
||||||
else
|
else
|
||||||
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
|
NativeOpExecutioner::execInverseBroadcastBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr, copy.data(), (int)copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets());
|
||||||
registerSpecialUse({result}, {this, other});
|
registerSpecialUse({result}, {this, other});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -284,6 +284,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ);
|
||||||
|
@ -296,6 +297,7 @@ static void execScalarInt(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *result, Nd4jLong *resultShapeInfo,
|
void *result, Nd4jLong *resultShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
|
@ -179,6 +179,7 @@ ND4J_EXPORT void execBroadcastBool(
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
void *hDimension, Nd4jLong *hDimensionShape,
|
||||||
void *dDimension, Nd4jLong *dDimensionShape);
|
void *dDimension, Nd4jLong *dDimensionShape);
|
||||||
|
|
||||||
|
|
|
@ -156,6 +156,9 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
|
@ -187,7 +190,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
||||||
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
|
if ((yType != xType && yType != nd4j::DataType::BOOL) || xType != zType)
|
||||||
|
@ -219,6 +223,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
@ -228,8 +233,11 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto xLen = shape::length(hXShapeInfo);
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
@ -247,22 +255,24 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
if (!nd4j::Environment::getInstance()->isExperimentalBuild())
|
||||||
if (yType != xType || nd4j::DataType::BOOL != zType)
|
if (yType != xType || nd4j::DataType::BOOL != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOps::execInverseBroadcastBool both operands must have same data type", xType, yType);
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverse(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto xLen = shape::length(hXShapeInfo);
|
auto xLen = shape::length(hXShapeInfo);
|
||||||
|
@ -292,6 +302,9 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execBroadcastInt", zType, xType, yType);
|
||||||
|
|
||||||
|
@ -321,12 +334,13 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execInverseBroadcastInt", zType, xType, yType);
|
||||||
|
|
||||||
|
@ -367,11 +381,13 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::pairwise_transforms::PairWiseTransform, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
|
@ -403,6 +419,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType)
|
if (xType != yType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform", xType, yType);
|
||||||
|
|
||||||
|
@ -429,11 +448,13 @@ void NativeOpExecutioner::execPairwiseIntTransform(nd4j::LaunchContext *lc,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *extraParams) {
|
void *extraParams) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform", zType, xType, yType);
|
||||||
|
|
||||||
|
@ -837,11 +858,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
void *dScalar, Nd4jLong *dScalarShapeInfo,
|
void *dScalar, Nd4jLong *dScalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, hScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
|
@ -872,11 +895,13 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::transform(opNum, hX, hXShapeInfo, extraParams, hZ, hZShapeInfo, hScalars, dimension, dimensionLength, tadShapeInfo, tadOffsets, tadShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
#else
|
#else
|
||||||
|
@ -904,12 +929,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType)
|
if (xType != yType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
|
||||||
|
|
||||||
|
@ -939,11 +965,13 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType)
|
if (xType != yType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarBool", xType, yType);
|
||||||
|
|
||||||
|
@ -969,12 +997,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
void *dScalar, Nd4jLong *dSscalarShapeInfo,
|
||||||
void *extraParams, bool allowParallelism) {
|
void *extraParams, bool allowParallelism) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hSscalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hSscalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
|
||||||
|
|
||||||
|
@ -1004,11 +1033,13 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || xType != zType)
|
if (xType != yType || xType != zType)
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execScalarInt", xType, yType);
|
||||||
|
|
||||||
|
@ -1126,6 +1157,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_DO {
|
auto func = PRAGMA_THREADS_DO {
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
@ -1145,6 +1179,9 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_DO {
|
auto func = PRAGMA_THREADS_DO {
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
};
|
};
|
||||||
|
@ -1164,6 +1201,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_DO {
|
auto func = PRAGMA_THREADS_DO {
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
};
|
};
|
||||||
|
@ -1183,6 +1223,9 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_DO {
|
auto func = PRAGMA_THREADS_DO {
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), LIBND4J_TYPES);
|
||||||
};
|
};
|
||||||
|
@ -1202,6 +1245,9 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
|
||||||
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
auto xType = nd4j::ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_DO {
|
auto func = PRAGMA_THREADS_DO {
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::exec(opNum, hX, hXShapeInfo, hZ, hZShapeInfo, extraParams, thread_id, numThreads), FLOAT_TYPES);
|
||||||
};
|
};
|
||||||
|
|
|
@ -231,8 +231,9 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
void *extraParams,
|
||||||
void *dDimension, Nd4jLong *dDimensionShape) {
|
void *hDimension, Nd4jLong *hDimensionShape,
|
||||||
|
void *dDimension, Nd4jLong *dDimensionShape) {
|
||||||
try {
|
try {
|
||||||
auto dimension = reinterpret_cast<int *>(hDimension);
|
auto dimension = reinterpret_cast<int *>(hDimension);
|
||||||
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
int dimensionLength = static_cast<int>(shape::length(hDimensionShape));
|
||||||
|
@ -259,6 +260,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
||||||
dYShapeInfo,
|
dYShapeInfo,
|
||||||
hZ, hZShapeInfo,
|
hZ, hZShapeInfo,
|
||||||
dZ, dZShapeInfo,
|
dZ, dZShapeInfo,
|
||||||
|
extraParams,
|
||||||
dimension,
|
dimension,
|
||||||
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ,
|
dimensionLength, hTADShapeInfo, hTADOffsets, hTADShapeInfoZ,
|
||||||
hTADOffsetsZ);
|
hTADOffsetsZ);
|
||||||
|
|
|
@ -101,6 +101,9 @@ void NativeOpExecutioner::execPairwiseTransform(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != zType && yType != zType)
|
if (xType != zType && yType != zType)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type");
|
throw std::runtime_error("NativeOpExecutioner::execPairwiseTransform requires Z operand to have either X or Y type");
|
||||||
if (lc == nullptr)
|
if (lc == nullptr)
|
||||||
|
@ -139,6 +142,9 @@ void NativeOpExecutioner::execPairwiseBoolTransform( nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isB(zType))
|
if (!DataTypeUtils::isB(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseBoolTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
|
||||||
|
|
||||||
|
@ -172,6 +178,9 @@ void NativeOpExecutioner::execPairwiseIntTransform( nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isZ(zType))
|
if (!DataTypeUtils::isZ(zType))
|
||||||
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
|
throw nd4j::datatype_exception::build("NativeOpExecutioner::execPairwiseIntTransform wrong Z operand data type", nd4j::DataType::BOOL, zType);
|
||||||
|
|
||||||
|
@ -223,6 +232,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
@ -233,6 +243,9 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isB(zType))
|
if (!DataTypeUtils::isB(zType))
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
|
||||||
|
|
||||||
|
@ -244,7 +257,7 @@ void NativeOpExecutioner::execBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
@ -260,6 +273,7 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets,
|
||||||
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfoZ,Nd4jLong *tadOffsetsZ) {
|
||||||
|
@ -269,18 +283,18 @@ void NativeOpExecutioner::execInverseBroadcastBool(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isB(zType))
|
if (!DataTypeUtils::isB(zType))
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires Z operand to have BOOL type");
|
||||||
|
|
||||||
if (yType != xType)
|
if (yType != xType)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastBool requires both X & Y operands to have same type");
|
||||||
|
|
||||||
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
|
|
||||||
printf("F3BI opNum:[%i]\n", opNum);
|
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::broadcast::BroadcastBool, ::execInverseBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, BOOL_TYPES)
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
auto res = cudaStreamSynchronize(*stream);
|
auto res = cudaStreamSynchronize(*stream);
|
||||||
|
@ -308,15 +322,15 @@ void NativeOpExecutioner::execBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isZ(zType))
|
if (!DataTypeUtils::isZ(zType))
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
|
||||||
|
|
||||||
if (yType != xType || zType != xType)
|
if (yType != xType || zType != xType)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires both X & Y operands to have same type");
|
||||||
|
|
||||||
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
|
|
||||||
printf("F3B opNum:[%i]\n", opNum);
|
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)
|
BUILD_SINGLE_SELECTOR(xType, functions::broadcast::BroadcastInt, ::execBroadcast(launchDims, stream, opNum, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), INTEGER_TYPES)
|
||||||
|
@ -344,6 +358,9 @@ void NativeOpExecutioner::execInverseBroadcastInt(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isZ(zType))
|
if (!DataTypeUtils::isZ(zType))
|
||||||
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
|
throw std::runtime_error("NativeOpExecutioner::execBroadcastInt requires Z operand to have INT type");
|
||||||
|
|
||||||
|
@ -394,8 +411,8 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
printf("F3 opNum:[%i]\n", opNum);
|
return;
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
||||||
|
@ -429,8 +446,8 @@ void NativeOpExecutioner::execInverseBroadcast(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hYShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (nd4j::Environment::getInstance()->isDebugAndVerbose())
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hYShapeInfo))
|
||||||
printf("F3I opNum:[%i]\n", opNum);
|
return;
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 1024);
|
dim3 launchDims(256, 256, 1024);
|
||||||
|
|
||||||
|
@ -832,16 +849,21 @@ void NativeOpExecutioner::execTransformSame(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
dim3 launchDims(512, 512, 16384);
|
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
auto xRank = shape::rank(hXShapeInfo);
|
||||||
auto zRank = shape::rank(hZShapeInfo);
|
auto zRank = shape::rank(hZShapeInfo);
|
||||||
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (xType != zType)
|
if (shape::isEmpty(hXShapeInfo)) {
|
||||||
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type");
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (xType != zType) {
|
||||||
|
throw std::runtime_error("NativeOpExecutioner::execTransformSame requires X & Z to have same type");
|
||||||
|
}
|
||||||
|
|
||||||
|
dim3 launchDims(512, 512, 16384);
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformSame, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
|
@ -861,16 +883,21 @@ void NativeOpExecutioner::execTransformBool(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
dim3 launchDims(512, 512, 16384);
|
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
auto xRank = shape::rank(hXShapeInfo);
|
||||||
auto zRank = shape::rank(hZShapeInfo);
|
auto zRank = shape::rank(hZShapeInfo);
|
||||||
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (!DataTypeUtils::isB(zType))
|
if (shape::isEmpty(hXShapeInfo)) {
|
||||||
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!DataTypeUtils::isB(zType)) {
|
||||||
|
throw std::runtime_error("NativeOpExecutioner::execTransformBool requires Z to have same boolean type");
|
||||||
|
}
|
||||||
|
|
||||||
|
dim3 launchDims(512, 512, 16384);
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformBool, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, BOOL_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
|
@ -896,6 +923,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc,
|
||||||
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
dim3 launchDims(512, 512, 2048);
|
dim3 launchDims(512, 512, 2048);
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
@ -917,16 +947,21 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext *lc,
|
||||||
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) {
|
||||||
|
|
||||||
auto stream = lc->getCudaStream();
|
auto stream = lc->getCudaStream();
|
||||||
dim3 launchDims(512, 512, 16384);
|
|
||||||
|
|
||||||
auto xRank = shape::rank(hXShapeInfo);
|
auto xRank = shape::rank(hXShapeInfo);
|
||||||
auto zRank = shape::rank(hZShapeInfo);
|
auto zRank = shape::rank(hZShapeInfo);
|
||||||
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
if (xType != zType || !DataTypeUtils::isR(xType))
|
if (shape::isEmpty(hXShapeInfo)) {
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (xType != zType || !DataTypeUtils::isR(xType)) {
|
||||||
|
throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType);
|
||||||
|
}
|
||||||
|
|
||||||
|
dim3 launchDims(512, 512, 16384);
|
||||||
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES);
|
||||||
|
|
||||||
// TODO: remove after the release
|
// TODO: remove after the release
|
||||||
|
@ -953,6 +988,9 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc,
|
||||||
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
auto xType = ArrayOptions::dataType(hXShapeInfo);
|
||||||
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
auto zType = ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (!DataTypeUtils::isR(zType))
|
if (!DataTypeUtils::isR(zType))
|
||||||
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
|
throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType);
|
||||||
|
|
||||||
|
@ -1175,6 +1213,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType )
|
if (xType != yType )
|
||||||
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
|
||||||
|
|
||||||
|
@ -1211,6 +1252,9 @@ void NativeOpExecutioner::execScalarBool(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType )
|
if (xType != yType )
|
||||||
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execScalarBool requires X & Y to have same type");
|
||||||
|
|
||||||
|
@ -1244,6 +1288,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || zType != xType)
|
if (xType != yType || zType != xType)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
|
||||||
|
|
||||||
|
@ -1280,6 +1327,9 @@ void NativeOpExecutioner::execScalarInt(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
if (xType != yType || zType != xType)
|
if (xType != yType || zType != xType)
|
||||||
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
|
throw std::runtime_error("NativeOpExecutioner::execScalarInt requires X & Y to have same type");
|
||||||
|
|
||||||
|
@ -1313,6 +1363,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::scalar::ScalarTransform, ::executeCudaShaped(launchDims, stream, opNum, dX, dXShapeInfo, hXShapeInfo, dZ, dZShapeInfo, hZShapeInfo, dScalar, extraParams), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
@ -1346,6 +1399,9 @@ void NativeOpExecutioner::execScalar(nd4j::LaunchContext *lc,
|
||||||
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
auto yType = nd4j::ArrayOptions::dataType(hScalarShapeInfo);
|
||||||
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo);
|
||||||
|
|
||||||
|
if (shape::isEmpty(hXShapeInfo) || shape::isEmpty(hScalarShapeInfo))
|
||||||
|
return;
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 16384);
|
dim3 launchDims(256, 256, 16384);
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
|
|
|
@ -294,6 +294,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
||||||
void *dY, Nd4jLong *dYShapeInfo,
|
void *dY, Nd4jLong *dYShapeInfo,
|
||||||
void *hZ, Nd4jLong *hZShapeInfo,
|
void *hZ, Nd4jLong *hZShapeInfo,
|
||||||
void *dZ, Nd4jLong *dZShapeInfo,
|
void *dZ, Nd4jLong *dZShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
void *hDimension, Nd4jLong *hDimensionShape,
|
void *hDimension, Nd4jLong *hDimensionShape,
|
||||||
void *dDimension, Nd4jLong *dDimensionShape) {
|
void *dDimension, Nd4jLong *dDimensionShape) {
|
||||||
try {
|
try {
|
||||||
|
@ -313,7 +314,7 @@ void execBroadcastBool(Nd4jPointer *extraPointers,
|
||||||
|
|
||||||
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]);
|
||||||
NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY,
|
NativeOpExecutioner::execBroadcastBool(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY,
|
||||||
dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, dimension,
|
dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, dimension,
|
||||||
dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ,
|
dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ,
|
||||||
tadOffsetsZ);
|
tadOffsetsZ);
|
||||||
} catch (std::exception &e) {
|
} catch (std::exception &e) {
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
#include <TrueBroadcastHelper.h>
|
#include <TrueBroadcastHelper.h>
|
||||||
#include <ops/ops.h>
|
#include <ops/ops.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
|
||||||
using namespace simdOps;
|
using namespace simdOps;
|
||||||
|
|
||||||
|
@ -47,36 +48,39 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
||||||
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
if(ix >= 0) {
|
if (ix >= 0) {
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
xCoords[ix--] = zCoords[iz];
|
xCoords[ix--] = zCoords[iz];
|
||||||
} else {
|
} else {
|
||||||
xCoords[ix--] = 0;
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iy >= 0) {
|
||||||
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
yCoords[iy--] = zCoords[iz];
|
||||||
|
} else {
|
||||||
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(iy >= 0) {
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
yCoords[iy--] = zCoords[iz];
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y, typename Z>
|
template <typename X, typename Y, typename Z>
|
||||||
|
@ -103,38 +107,40 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
if (ix >= 0) {
|
||||||
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
xCoords[ix--] = zCoords[iz];
|
||||||
|
} else {
|
||||||
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if(ix >= 0) {
|
if (iy >= 0) {
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
xCoords[ix--] = zCoords[iz];
|
yCoords[iy--] = zCoords[iz];
|
||||||
} else {
|
} else {
|
||||||
xCoords[ix--] = 0;
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(iy >= 0) {
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
yCoords[iy--] = zCoords[iz];
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
|
@ -163,36 +169,39 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
|
||||||
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(xCoords, yCoords, zCoords))
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
for (Nd4jLong i = 0; i < zLen; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
||||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
if(ix >= 0) {
|
if (ix >= 0) {
|
||||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||||
xCoords[ix--] = zCoords[iz];
|
xCoords[ix--] = zCoords[iz];
|
||||||
} else {
|
} else {
|
||||||
xCoords[ix--] = 0;
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iy >= 0) {
|
||||||
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||||
|
yCoords[iy--] = zCoords[iz];
|
||||||
|
} else {
|
||||||
|
yCoords[iy--] = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(iy >= 0) {
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||||
yCoords[iy--] = zCoords[iz];
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||||
} else {
|
|
||||||
yCoords[iy--] = 0;
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
samediff::Threads::parallel_for(func, 0, zLen);
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename X>
|
template <typename X>
|
||||||
|
|
|
@ -86,7 +86,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -172,7 +172,7 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
const auto yOffset = shape::getOffset(yShapeInfo, yCoords);
|
||||||
|
|
||||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -65,13 +65,14 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
template <typename OpClass>
|
template <typename OpClass>
|
||||||
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
template<typename OpType>
|
template<typename OpType>
|
||||||
static __device__ void transformInverseCuda(
|
static __device__ void transformInverseCuda(
|
||||||
|
@ -81,13 +82,14 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
template <typename OpClass>
|
template <typename OpClass>
|
||||||
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
static __host__ void execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *result, Nd4jLong *resultShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ);
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
|
@ -98,6 +100,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
|
@ -114,6 +117,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
|
@ -141,6 +145,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
|
@ -157,6 +162,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *result,
|
void *result,
|
||||||
Nd4jLong *resultShapeInfo,
|
Nd4jLong *resultShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *tadShapeInfo,
|
Nd4jLong *tadShapeInfo,
|
||||||
|
|
|
@ -39,6 +39,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
|
@ -53,6 +54,7 @@ namespace functions {
|
||||||
yShapeInfo,
|
yShapeInfo,
|
||||||
z,
|
z,
|
||||||
zShapeInfo,
|
zShapeInfo,
|
||||||
|
extraParams,
|
||||||
dimension,
|
dimension,
|
||||||
dimensionLength,
|
dimensionLength,
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
|
@ -69,6 +71,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
|
@ -83,6 +86,7 @@ namespace functions {
|
||||||
yShapeInfo,
|
yShapeInfo,
|
||||||
z,
|
z,
|
||||||
zShapeInfo,
|
zShapeInfo,
|
||||||
|
extraParams,
|
||||||
dimension,
|
dimension,
|
||||||
dimensionLength,
|
dimensionLength,
|
||||||
xTadShapeInfo,
|
xTadShapeInfo,
|
||||||
|
@ -99,6 +103,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *vextraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *xTadShapeInfo,
|
Nd4jLong *xTadShapeInfo,
|
||||||
|
@ -111,6 +116,7 @@ namespace functions {
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
auto extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
//decompose in to several sub tads after
|
//decompose in to several sub tads after
|
||||||
//moving all dimensions (in sorted order)
|
//moving all dimensions (in sorted order)
|
||||||
|
@ -155,7 +161,7 @@ namespace functions {
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(oX[f], y[f]);
|
oZ[f] = OpType::op(oX[f], y[f], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
|
@ -165,7 +171,7 @@ namespace functions {
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
|
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws], extraParams);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
@ -179,7 +185,7 @@ namespace functions {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[offset]);
|
oZ[offset] = OpType::op(oX[offset], y[offset], extraParams);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -197,7 +203,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[offset], y[offset]);
|
oZ[zOffset] = OpType::op(oX[offset], y[offset], extraParams);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -215,7 +221,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[offset], y[yOffset]);
|
oZ[offset] = OpType::op(oX[offset], y[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -234,7 +240,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(oX[xOffset], y[offset]);
|
oZ[offset] = OpType::op(oX[xOffset], y[offset], extraParams);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -255,7 +261,7 @@ namespace functions {
|
||||||
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xTadShapeShapeInfo, tadShapeShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yShapeInfo, yShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset]);
|
oZ[zOffset] = OpType::op(oX[xOffset], y[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -270,6 +276,7 @@ namespace functions {
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *vz,
|
void *vz,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *vextraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength,
|
int dimensionLength,
|
||||||
Nd4jLong *yTadShapeInfo,
|
Nd4jLong *yTadShapeInfo,
|
||||||
|
@ -282,6 +289,7 @@ namespace functions {
|
||||||
auto x = reinterpret_cast<X *>(vx);
|
auto x = reinterpret_cast<X *>(vx);
|
||||||
auto y = reinterpret_cast<X *>(vy);
|
auto y = reinterpret_cast<X *>(vy);
|
||||||
auto z = reinterpret_cast<Z *>(vz);
|
auto z = reinterpret_cast<Z *>(vz);
|
||||||
|
auto extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
//decompose in to several sub tads after
|
//decompose in to several sub tads after
|
||||||
//moving all dimensions (in sorted order)
|
//moving all dimensions (in sorted order)
|
||||||
|
@ -326,7 +334,7 @@ namespace functions {
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (unsigned int f = 0; f < tadLength; f++)
|
for (unsigned int f = 0; f < tadLength; f++)
|
||||||
oZ[f] = OpType::op(x[f], oY[f]);
|
oZ[f] = OpType::op(x[f], oY[f], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
else if(kindOfLoop == nd4j::LoopKind::EWSNONZERO) {
|
||||||
|
@ -336,7 +344,7 @@ namespace functions {
|
||||||
|
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (uint f = 0; f < tadLength; f++)
|
for (uint f = 0; f < tadLength; f++)
|
||||||
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws]);
|
oZ[f * zEws] = OpType::op(x[f * xEws], oY[f * yEws], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
else if(shape::haveSameShapeAndStrides(yTadShapeShapeInfo, xShapeInfo) && shape::haveSameShapeAndStrides(yTadShapeShapeInfo, zTadShapeInfo)) {
|
||||||
|
@ -351,7 +359,7 @@ namespace functions {
|
||||||
PRAGMA_OMP_SIMD
|
PRAGMA_OMP_SIMD
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[offset]);
|
oZ[offset] = OpType::op(x[offset], oY[offset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -370,7 +378,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[offset], oY[offset]);
|
oZ[zOffset] = OpType::op(x[offset], oY[offset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -389,7 +397,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto offset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[xOffset], oY[offset]);
|
oZ[offset] = OpType::op(x[xOffset], oY[offset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -408,7 +416,7 @@ namespace functions {
|
||||||
for (int f = 0; f < tadLength; f++) {
|
for (int f = 0; f < tadLength; f++) {
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto offset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
oZ[offset] = OpType::op(x[offset], oY[yOffset]);
|
oZ[offset] = OpType::op(x[offset], oY[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -430,7 +438,7 @@ namespace functions {
|
||||||
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
auto xOffset = shape::indexOffset(f, xShapeInfo, xShapeInfoCast, canCastX);
|
||||||
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
auto yOffset = shape::indexOffset(f, yTadShapeShapeInfo, tadShapeShapeInfoCast, canCastY);
|
||||||
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
auto zOffset = shape::indexOffset(f, zTadShapeInfo, tadShapeInfoZCast, canCastZ);
|
||||||
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset]);
|
oZ[zOffset] = OpType::op(x[xOffset], oY[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,10 +40,11 @@ static __global__ void broadcastBoolSimple(
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
|
functions::broadcast::BroadcastBool<X, Z>::template transformCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo, extraParams, dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -55,10 +56,11 @@ static __global__ void broadcastBoolInverseSimple(
|
||||||
Nd4jLong *yShapeInfo,
|
Nd4jLong *yShapeInfo,
|
||||||
void *z,
|
void *z,
|
||||||
Nd4jLong *zShapeInfo,
|
Nd4jLong *zShapeInfo,
|
||||||
|
void *extraParams,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
|
functions::broadcast::BroadcastBool<X, Z>::template transformInverseCuda<OpClass>(x,xShapeInfo,y,yShapeInfo,z,zShapeInfo,extraParams,dimension,dimensionLength,tadOnlyShapeInfo,tadOffsets,tadOnlyShapeInfoZ,tadOffsetsZ);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace functions {
|
namespace functions {
|
||||||
|
@ -66,15 +68,15 @@ namespace functions {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpClass>
|
template <typename OpClass>
|
||||||
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
__host__ void BroadcastBool<X,Z>::intermediateBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
broadcastBoolSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
__host__ void BroadcastBool<X,Y>::execBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
|
DISPATCH_BY_OPNUM_TT(intermediateBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
DEBUG_KERNEL(stream, opNum);
|
||||||
}
|
}
|
||||||
|
@ -82,15 +84,15 @@ namespace functions {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Z>
|
template<typename X, typename Z>
|
||||||
template <typename OpClass>
|
template <typename OpClass>
|
||||||
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
__host__ void BroadcastBool<X,Z>::intermediateInverseBroadcast(dim3 launchDims, cudaStream_t *stream, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
broadcastBoolInverseSimple<X, Z, OpClass><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ);
|
||||||
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
|
nd4j::DebugHelper::checkErrorCode(stream, "intermediateBroadcastBool(...) failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template<typename X, typename Y>
|
template<typename X, typename Y>
|
||||||
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
__host__ void BroadcastBool<X,Y>::execInverseBroadcast(dim3 launchDims, cudaStream_t *stream, int opNum, void *x, Nd4jLong *xShapeInfo, void *y, Nd4jLong *yShapeInfo, void *z, Nd4jLong *zShapeInfo, void *extraParams, int *dimension, int dimensionLength, Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
|
DISPATCH_BY_OPNUM_TT(intermediateInverseBroadcast, PARAMS(launchDims, stream, x, xShapeInfo, y, yShapeInfo, z, zShapeInfo, extraParams, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), OPS_A(BROADCAST_BOOL_OPS))
|
||||||
|
|
||||||
DEBUG_KERNEL(stream, opNum);
|
DEBUG_KERNEL(stream, opNum);
|
||||||
}
|
}
|
||||||
|
@ -102,6 +104,7 @@ namespace functions {
|
||||||
void *vx, Nd4jLong *xShapeInfo,
|
void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
|
void *vextraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
@ -113,6 +116,7 @@ namespace functions {
|
||||||
auto x = reinterpret_cast<X*>(vx);
|
auto x = reinterpret_cast<X*>(vx);
|
||||||
auto y = reinterpret_cast<X*>(vy);
|
auto y = reinterpret_cast<X*>(vy);
|
||||||
auto z = reinterpret_cast<Z*>(vz);
|
auto z = reinterpret_cast<Z*>(vz);
|
||||||
|
auto extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
//decompose in to several sub tads after
|
//decompose in to several sub tads after
|
||||||
//moving all dimensions (in sorted order)
|
//moving all dimensions (in sorted order)
|
||||||
|
@ -140,7 +144,7 @@ namespace functions {
|
||||||
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
|
if(tadEWS > 0 && zEWS > 0 && xEWS > 0 && dimensionLength == 1) {
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
|
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
|
||||||
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS]);
|
rZ[i * zEWS] = OpType::op(x[i * xEWS], rY[i * tadEWS], extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// it is expected that x and z tads and y array all have the same length
|
// it is expected that x and z tads and y array all have the same length
|
||||||
|
@ -149,7 +153,7 @@ namespace functions {
|
||||||
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
|
auto yOffset = shape::getIndexOffset(i, tadOnlyShapeInfo);
|
||||||
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
|
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
|
||||||
|
|
||||||
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset]);
|
rZ[zOffset] = OpType::op(x[xOffset], rY[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -162,6 +166,7 @@ namespace functions {
|
||||||
void *vx, Nd4jLong *xShapeInfo,
|
void *vx, Nd4jLong *xShapeInfo,
|
||||||
void *vy, Nd4jLong *yShapeInfo,
|
void *vy, Nd4jLong *yShapeInfo,
|
||||||
void *vz, Nd4jLong *zShapeInfo,
|
void *vz, Nd4jLong *zShapeInfo,
|
||||||
|
void *vextraParams,
|
||||||
int *dimension, int dimensionLength,
|
int *dimension, int dimensionLength,
|
||||||
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
Nd4jLong *tadOnlyShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadOnlyShapeInfoZ, Nd4jLong *tadOffsetsZ) {
|
||||||
|
|
||||||
|
@ -173,6 +178,7 @@ namespace functions {
|
||||||
auto x = reinterpret_cast<X*>(vx);
|
auto x = reinterpret_cast<X*>(vx);
|
||||||
auto y = reinterpret_cast<X*>(vy);
|
auto y = reinterpret_cast<X*>(vy);
|
||||||
auto z = reinterpret_cast<Z*>(vz);
|
auto z = reinterpret_cast<Z*>(vz);
|
||||||
|
auto extraParams = reinterpret_cast<X*>(vextraParams);
|
||||||
|
|
||||||
//decompose in to several sub tads after
|
//decompose in to several sub tads after
|
||||||
//moving all dimensions (in sorted order)
|
//moving all dimensions (in sorted order)
|
||||||
|
@ -208,7 +214,7 @@ namespace functions {
|
||||||
if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {
|
if(tadEWS > 0 && zEWS > 0 && yEWS > 0 && dimensionLength == 1) {
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
|
for (int i = threadIdx.x; i < tadLength; i+= blockDim.x)
|
||||||
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS]);
|
rZ[i * zEWS] = OpType::op(rX[i * tadEWS], y[i * yEWS], extraParams);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// it is expected that x and z tads and y array all have the same length
|
// it is expected that x and z tads and y array all have the same length
|
||||||
|
@ -217,7 +223,7 @@ namespace functions {
|
||||||
auto yOffset = shape::getIndexOffset(i, yShapeInfo);
|
auto yOffset = shape::getIndexOffset(i, yShapeInfo);
|
||||||
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
|
auto zOffset = shape::getIndexOffset(i, tadOnlyShapeInfoZ);
|
||||||
|
|
||||||
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset]);
|
rZ[zOffset] = OpType::op(rX[xOffset], y[yOffset], extraParams);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,13 +45,13 @@
|
||||||
(2, LessThan),\
|
(2, LessThan),\
|
||||||
(3, Epsilon),\
|
(3, Epsilon),\
|
||||||
(4, GreaterThanOrEqual),\
|
(4, GreaterThanOrEqual),\
|
||||||
(5, LessThanOrEqual),\
|
(5, MatchCondition) ,\
|
||||||
(6, NotEqualTo),\
|
(6, NotEqualTo),\
|
||||||
(7, And),\
|
(7, And),\
|
||||||
(8, Or),\
|
(8, Or),\
|
||||||
(9, Xor) ,\
|
(9, Xor) ,\
|
||||||
(10, Not)
|
(10, Not) ,\
|
||||||
|
(11, LessThanOrEqual)
|
||||||
|
|
||||||
#define BROADCAST_OPS \
|
#define BROADCAST_OPS \
|
||||||
(0, Add), \
|
(0, Add), \
|
||||||
|
@ -198,12 +198,13 @@
|
||||||
(2, LessThan),\
|
(2, LessThan),\
|
||||||
(3, Epsilon),\
|
(3, Epsilon),\
|
||||||
(4, GreaterThanOrEqual),\
|
(4, GreaterThanOrEqual),\
|
||||||
(5, LessThanOrEqual),\
|
(5, MatchCondition) ,\
|
||||||
(6, NotEqualTo),\
|
(6, NotEqualTo),\
|
||||||
(7, And),\
|
(7, And),\
|
||||||
(8, Or),\
|
(8, Or),\
|
||||||
(9, Xor) ,\
|
(9, Xor) ,\
|
||||||
(10, Not)
|
(10, Not) ,\
|
||||||
|
(11, LessThanOrEqual)
|
||||||
|
|
||||||
#define SCALAR_OPS \
|
#define SCALAR_OPS \
|
||||||
(0, Add),\
|
(0, Add),\
|
||||||
|
@ -341,12 +342,13 @@
|
||||||
(2, LessThan),\
|
(2, LessThan),\
|
||||||
(3, Epsilon),\
|
(3, Epsilon),\
|
||||||
(4, GreaterThanOrEqual),\
|
(4, GreaterThanOrEqual),\
|
||||||
(5, LessThanOrEqual),\
|
(5, MatchCondition) ,\
|
||||||
(6, NotEqualTo),\
|
(6, NotEqualTo),\
|
||||||
(7, And),\
|
(7, And),\
|
||||||
(8, Or),\
|
(8, Or),\
|
||||||
(9, Xor) ,\
|
(9, Xor) ,\
|
||||||
(10, Not)
|
(10, Not) ,\
|
||||||
|
(11, LessThanOrEqual)
|
||||||
|
|
||||||
#define PAIRWISE_TRANSFORM_OPS \
|
#define PAIRWISE_TRANSFORM_OPS \
|
||||||
(0, Add),\
|
(0, Add),\
|
||||||
|
|
|
@ -2302,54 +2302,66 @@ namespace simdOps {
|
||||||
return old + opOutput;
|
return old + opOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
// this op return 1.0 if condition met, 0.0 otherwise
|
op_def static Z op(X d1, X compare, X eps, int mode) {
|
||||||
op_def static Z op(X d1, X *extraParams) {
|
switch (mode) {
|
||||||
X compare = extraParams[0];
|
case 0: // equals
|
||||||
X eps = extraParams[1];
|
return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0;
|
||||||
|
case 1: // not equals
|
||||||
auto mode = static_cast<int>(extraParams[2]);
|
return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0;
|
||||||
//printf("value: %f; comp: %f; eps: %f; mode: %i;\n", (float) d1, (float) compare, (float) eps, mode);
|
case 2: // less_than
|
||||||
|
return d1 < compare ? 1 : 0;
|
||||||
switch (mode) {
|
case 3: // greater_than
|
||||||
case 0: // equals
|
return d1 > compare ? 1 : 0;
|
||||||
return nd4j::math::nd4j_abs<X>(d1 - compare) <= eps ? 1 : 0;
|
case 4: // less_or_equals_than
|
||||||
case 1: // not equals
|
return d1 <= compare ? 1 : 0;
|
||||||
return nd4j::math::nd4j_abs<X>(d1 - compare) > eps ? 1 : 0;
|
case 5: // greater_or_equals_than
|
||||||
case 2: // less_than
|
return d1 >= compare ? 1 : 0;
|
||||||
return d1 < compare ? 1 : 0;
|
case 6: // abs_less_than
|
||||||
case 3: // greater_than
|
return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0;
|
||||||
return d1 > compare ? 1 : 0;
|
case 7: // abs_greater_than
|
||||||
case 4: // less_or_equals_than
|
return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0;
|
||||||
return d1 <= compare ? 1 : 0;
|
case 8: // is inf
|
||||||
case 5: // greater_or_equals_than
|
return nd4j::math::nd4j_isinf(d1) ? 1 : 0;
|
||||||
return d1 >= compare ? 1 : 0;
|
case 9: // is nan
|
||||||
case 6: // abs_less_than
|
return nd4j::math::nd4j_isnan(d1) ? 1 : 0;
|
||||||
return nd4j::math::nd4j_abs<X>(d1) < compare ? 1 : 0;
|
case 10:
|
||||||
case 7: // abs_greater_than
|
return (d1 == compare) ? 1 : 0;
|
||||||
return nd4j::math::nd4j_abs<X>(d1) > compare ? 1 : 0;
|
case 11:
|
||||||
case 8: // is inf
|
return (d1 != compare) ? 1 : 0;
|
||||||
return nd4j::math::nd4j_isinf(d1) ? 1 : 0;
|
case 12: // abs_greater_or_equals_than
|
||||||
case 9: // is nan
|
return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
|
||||||
return nd4j::math::nd4j_isnan(d1) ? 1 : 0;
|
case 13: // abs_less_or_equals_than
|
||||||
case 10:
|
return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
|
||||||
return (d1 == compare) ? 1 : 0;
|
|
||||||
case 11:
|
|
||||||
return (d1 != compare) ? 1 : 0;
|
|
||||||
case 12: // abs_greater_or_equals_than
|
|
||||||
return nd4j::math::nd4j_abs<X>(d1) >= compare ? 1 : 0;
|
|
||||||
case 13: // abs_less_or_equals_than
|
|
||||||
return nd4j::math::nd4j_abs<X>(d1) <= compare ? 1 : 0;
|
|
||||||
case 14:
|
case 14:
|
||||||
// isFinite
|
// isFinite
|
||||||
return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0;
|
return !(nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1)) ? 1 : 0;
|
||||||
case 15:
|
case 15:
|
||||||
// isInfinite
|
// isInfinite
|
||||||
return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0;
|
return nd4j::math::nd4j_isinf(d1) || nd4j::math::nd4j_isnan(d1) ? 1 : 0;
|
||||||
default:
|
default:
|
||||||
printf("Undefined match condition: [%i]\n", mode);
|
printf("Undefined match condition: [%i]\n", mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
return d1;
|
return d1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// this op return 1.0 if condition met, 0.0 otherwise
|
||||||
|
op_def static Z op(X d1, X compare, X *extraParams) {
|
||||||
|
X eps = extraParams[1];
|
||||||
|
|
||||||
|
auto mode = static_cast<int>(extraParams[0]);
|
||||||
|
|
||||||
|
return op(d1, compare, eps, mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
// this op return 1.0 if condition met, 0.0 otherwise
|
||||||
|
op_def static Z op(X d1, X *extraParams) {
|
||||||
|
X compare = extraParams[0];
|
||||||
|
X eps = extraParams[1];
|
||||||
|
|
||||||
|
auto mode = static_cast<int>(extraParams[2]);
|
||||||
|
|
||||||
|
return op(d1, compare, eps, mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) {
|
op_def static Z postProcess(Z reduction, Nd4jLong n, X *extraParams) {
|
||||||
|
|
|
@ -1342,6 +1342,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_1) {
|
||||||
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
(int*)devicePtrs[0], dimensions.size(),
|
(int*)devicePtrs[0], dimensions.size(),
|
||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
|
||||||
nullptr, nullptr);
|
nullptr, nullptr);
|
||||||
|
@ -1400,6 +1401,7 @@ TEST_F(CudaBasicsTests1, execBroadcastBool_2) {
|
||||||
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
nullptr, x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
nullptr, y.getShapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
nullptr, z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
(int*)devicePtrs[0], dimensions.size(),
|
(int*)devicePtrs[0], dimensions.size(),
|
||||||
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
|
(Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2],
|
||||||
nullptr, nullptr);
|
nullptr, nullptr);
|
||||||
|
|
|
@ -674,6 +674,7 @@ TEST_F(LegacyOpsTests, test_inverse_broadcast_2) {
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
nullptr, 0,
|
nullptr, 0,
|
||||||
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets(),
|
||||||
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
tadPackY.platformShapeInfo(), tadPackY.platformOffsets());
|
||||||
|
|
|
@ -202,6 +202,7 @@ printf("Unsupported for cuda now.\n");
|
||||||
nullptr, nullptr,
|
nullptr, nullptr,
|
||||||
exp.buffer(), exp.shapeInfo(),
|
exp.buffer(), exp.shapeInfo(),
|
||||||
nullptr, nullptr,
|
nullptr, nullptr,
|
||||||
|
nullptr,
|
||||||
dimension.buffer(), dimension.shapeInfo(),
|
dimension.buffer(), dimension.shapeInfo(),
|
||||||
nullptr, nullptr);
|
nullptr, nullptr);
|
||||||
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
|
ASSERT_TRUE(exp.e<bool>(1) && !exp.e<bool>(0));
|
||||||
|
|
|
@ -417,7 +417,7 @@ public class LegacyOpMapper {
|
||||||
case 4:
|
case 4:
|
||||||
return ScalarGreaterThanOrEqual.class;
|
return ScalarGreaterThanOrEqual.class;
|
||||||
case 5:
|
case 5:
|
||||||
return ScalarLessThanOrEqual.class;
|
return MatchCondition.class;
|
||||||
case 6:
|
case 6:
|
||||||
return ScalarNotEquals.class;
|
return ScalarNotEquals.class;
|
||||||
case 7:
|
case 7:
|
||||||
|
@ -428,6 +428,8 @@ public class LegacyOpMapper {
|
||||||
return ScalarXor.class;
|
return ScalarXor.class;
|
||||||
case 10:
|
case 10:
|
||||||
return ScalarNot.class;
|
return ScalarNot.class;
|
||||||
|
case 11:
|
||||||
|
return ScalarLessThanOrEqual.class;
|
||||||
default:
|
default:
|
||||||
throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum);
|
throw new UnsupportedOperationException("No known scalar bool op for op number: " + opNum);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1864,7 +1864,10 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray match(INDArray comp, Condition condition) {
|
public INDArray match(INDArray comp, Condition condition) {
|
||||||
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp,condition));
|
// TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition
|
||||||
|
Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal");
|
||||||
|
Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal");
|
||||||
|
return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -64,7 +64,7 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
return 5;
|
return 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -55,7 +55,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int opNum() {
|
public int opNum() {
|
||||||
return 5;
|
return 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,6 +53,11 @@ public class MatchConditionTransform extends BaseTransformBoolOp {
|
||||||
|
|
||||||
public MatchConditionTransform() {}
|
public MatchConditionTransform() {}
|
||||||
|
|
||||||
|
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray y, @NonNull INDArray z, @NonNull Condition condition) {
|
||||||
|
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
||||||
|
this.y = y;
|
||||||
|
}
|
||||||
|
|
||||||
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) {
|
public MatchConditionTransform(@NonNull INDArray x, @NonNull INDArray z, @NonNull Condition condition) {
|
||||||
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
this(x, z, Nd4j.EPS_THRESHOLD, condition);
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,71 +25,283 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
*/
|
*/
|
||||||
public class Conditions {
|
public class Conditions {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is infinite
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition isInfinite() {
|
public static Condition isInfinite() {
|
||||||
return new IsInfinite();
|
return new IsInfinite();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is NaN
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition isNan() {
|
public static Condition isNan() {
|
||||||
return new IsNaN();
|
return new IsNaN();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is finite
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition isFinite() {
|
public static Condition isFinite() {
|
||||||
return new IsFinite();
|
return new IsFinite();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is NOT finite
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition notFinite() {
|
public static Condition notFinite() {
|
||||||
return new NotFinite();
|
return new NotFinite();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are not equal wrt eps
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition epsNotEquals() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return epsNotEquals(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are not equal wrt eps
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition epsNotEquals(Number value) {
|
public static Condition epsNotEquals(Number value) {
|
||||||
return new EpsilonNotEquals(value);
|
return new EpsilonNotEquals(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are equal wrt eps
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition epsEquals() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return epsEquals(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are equal wrt eps
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition epsEquals(Number value) {
|
public static Condition epsEquals(Number value) {
|
||||||
return epsEquals(value, Nd4j.EPS_THRESHOLD);
|
return epsEquals(value, Nd4j.EPS_THRESHOLD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are equal wrt eps
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition epsEquals(Number value, Number epsilon) {
|
public static Condition epsEquals(Number value, Number epsilon) {
|
||||||
return new EpsilonEquals(value, epsilon.doubleValue());
|
return new EpsilonEquals(value, epsilon.doubleValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are equal
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition equals() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return equals(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are equal
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition equals(Number value) {
|
public static Condition equals(Number value) {
|
||||||
return new EqualsCondition(value);
|
return new EqualsCondition(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are not equal
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition notEquals() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return notEquals(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is two values are not equal
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition notEquals(Number value) {
|
public static Condition notEquals(Number value) {
|
||||||
return new NotEqualsCondition(value);
|
return new NotEqualsCondition(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than value Y
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition greaterThan() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return greaterThan(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than value Y
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition greaterThan(Number value) {
|
public static Condition greaterThan(Number value) {
|
||||||
return new GreaterThan(value);
|
return new GreaterThan(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than value Y
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition lessThan() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return lessThan(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than value Y
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition lessThan(Number value) {
|
public static Condition lessThan(Number value) {
|
||||||
return new LessThan(value);
|
return new LessThan(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than or equal to value Y
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition lessThanOrEqual() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return lessThanOrEqual(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than or equal to value Y
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition lessThanOrEqual(Number value) {
|
public static Condition lessThanOrEqual(Number value) {
|
||||||
return new LessThanOrEqual(value);
|
return new LessThanOrEqual(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than or equal to value Y
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition greaterThanOrEqual() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return greaterThanOrEqual(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than or equal to value Y
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition greaterThanOrEqual(Number value) {
|
public static Condition greaterThanOrEqual(Number value) {
|
||||||
return new GreaterThanOrEqual(value);
|
return new GreaterThanOrEqual(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition absGreaterThanOrEqual() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return absGreaterThanOrEqual(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than or equal to value Y in absolute values
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition absGreaterThanOrEqual(Number value) {
|
public static Condition absGreaterThanOrEqual(Number value) {
|
||||||
return new AbsValueGreaterOrEqualsThan(value);
|
return new AbsValueGreaterOrEqualsThan(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition absLessThanOrEqual() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return absLessThanOrEqual(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than or equal to value Y in absolute values
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition absLessThanOrEqual(Number value) {
|
public static Condition absLessThanOrEqual(Number value) {
|
||||||
return new AbsValueLessOrEqualsThan(value);
|
return new AbsValueLessOrEqualsThan(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than value Y in absolute values
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition absGreaterThan() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return absGreaterThan(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is greater than value Y in absolute values
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition absGreaterThan(Number value) {
|
public static Condition absGreaterThan(Number value) {
|
||||||
return new AbsValueGreaterThan(value);
|
return new AbsValueGreaterThan(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than value Y in absolute values
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This condition should be used only with pairwise methods, i.e. INDArray.match(...)
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static Condition absLessThan() {
|
||||||
|
// in case of pairwise MatchCondition we don't really care about number here
|
||||||
|
return absLessThan(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method will create Condition that checks if value is value X is less than value Y in absolute values
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
public static Condition absLessThan(Number value) {
|
public static Condition absLessThan(Number value) {
|
||||||
return new AbsValueLessThan(value);
|
return new AbsValueLessThan(value);
|
||||||
}
|
}
|
||||||
|
|
|
@ -129,6 +129,7 @@ public interface NativeOps {
|
||||||
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer resultShapeInfo,
|
||||||
Pointer dresult,
|
Pointer dresult,
|
||||||
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
@Cast("Nd4jLong *") LongPointer dresultShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong *") LongPointer hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong *") LongPointer dDimensionShape);
|
||||||
|
|
||||||
|
|
|
@ -198,7 +198,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo,
|
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), x, (LongPointer) xShapeInfo,
|
||||||
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
|
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), y, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
|
||||||
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
|
null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), z, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
|
||||||
null,
|
null, null,
|
||||||
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
||||||
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
|
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
|
||||||
null);
|
null);
|
||||||
|
@ -805,7 +805,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
|
null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo,
|
||||||
null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo,
|
null, (LongPointer) hostYShapeInfo, y, (LongPointer) yShapeInfo,
|
||||||
null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
|
null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo,
|
||||||
null,
|
null, null,
|
||||||
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
||||||
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
|
AtomicAllocator.getInstance().getPointer(op.dimensions(), context),
|
||||||
null);
|
null);
|
||||||
|
|
|
@ -916,6 +916,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
|
||||||
public native void execBroadcastBool(
|
public native void execBroadcastBool(
|
||||||
|
@ -927,6 +928,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
|
||||||
public native void execBroadcastBool(
|
public native void execBroadcastBool(
|
||||||
|
@ -938,6 +940,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);
|
||||||
|
|
||||||
|
|
|
@ -967,6 +967,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
null, null,
|
null, null,
|
||||||
op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(),
|
op.z().data().addressPointer(), (LongPointer) op.z().shapeInfoDataBuffer().addressPointer(),
|
||||||
null, null,
|
null, null,
|
||||||
|
null,
|
||||||
op.dimensions().data().addressPointer(),
|
op.dimensions().data().addressPointer(),
|
||||||
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
(LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(),
|
||||||
null,
|
null,
|
||||||
|
|
|
@ -916,6 +916,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") LongPointer dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") LongPointer hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") LongPointer dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") LongPointer hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") LongPointer dDimensionShape);
|
||||||
public native void execBroadcastBool(
|
public native void execBroadcastBool(
|
||||||
|
@ -927,6 +928,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") LongBuffer dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") LongBuffer hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") LongBuffer dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") LongBuffer hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") LongBuffer dDimensionShape);
|
||||||
public native void execBroadcastBool(
|
public native void execBroadcastBool(
|
||||||
|
@ -938,6 +940,7 @@ public native void execBroadcastBool(
|
||||||
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
|
Pointer dY, @Cast("Nd4jLong*") long[] dYShapeInfo,
|
||||||
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
|
Pointer hZ, @Cast("Nd4jLong*") long[] hZShapeInfo,
|
||||||
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
|
Pointer dZ, @Cast("Nd4jLong*") long[] dZShapeInfo,
|
||||||
|
Pointer extraParams,
|
||||||
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
|
Pointer hDimension, @Cast("Nd4jLong*") long[] hDimensionShape,
|
||||||
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);
|
Pointer dDimension, @Cast("Nd4jLong*") long[] dDimensionShape);
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
import org.nd4j.linalg.indexing.NDArrayIndex;
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
||||||
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
@ -1087,6 +1088,24 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testMatch_1() {
|
||||||
|
INDArray x = Nd4j.ones(DataType.FLOAT, 3,3);
|
||||||
|
INDArray y = Nd4j.linspace(DataType.FLOAT, -5, 9, 1).reshape(3, 3);
|
||||||
|
val c = Conditions.equals(0.0);
|
||||||
|
|
||||||
|
System.out.println("Y:\n" + y);
|
||||||
|
|
||||||
|
INDArray z = x.match(y, c);
|
||||||
|
INDArray exp = Nd4j.createFromArray(new boolean[][]{
|
||||||
|
{false, false, false},
|
||||||
|
{false, false, false},
|
||||||
|
{true, false, false}
|
||||||
|
});
|
||||||
|
|
||||||
|
assertEquals(exp, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCreateOp_1() {
|
public void testCreateOp_1() {
|
||||||
|
|
Loading…
Reference in New Issue