diff --git a/libnd4j/include/loops/cpu/random.cpp b/libnd4j/include/loops/cpu/random.cpp index d4c808719..6fccc6376 100644 --- a/libnd4j/include/loops/cpu/random.cpp +++ b/libnd4j/include/loops/cpu/random.cpp @@ -29,6 +29,7 @@ using namespace randomOps; namespace functions { namespace random { + template<typename X> template<typename OpClass> void RandomFunction<X>::execTransform(Nd4jPointer state, @@ -56,18 +57,32 @@ namespace functions { if(shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - uint xShapeInfoCast[MAX_RANK]; - const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (auto i = start; i < stop; i += increment) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); - } - }; + if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::elementWiseStride(yShapeInfo) == 1 && + shape::order(xShapeInfo) == shape::order(zShapeInfo) && shape::order(zShapeInfo) == shape::order(yShapeInfo) ){ + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], y[i], i, length, rng, extraArguments); + } + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } + else{ + uint xShapeInfoCast[MAX_RANK]; + const bool canCastX = nd4j::DataTypeUtils::castShapeInfo(xShapeInfo, xShapeInfoCast); - samediff::Threads::parallel_for(func, 0, length, 1); + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], y[offset], i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } } else if (shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo)) { @@ 
-169,15 +184,27 @@ namespace functions { if(shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo)) { - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (uint64_t i = start; i < stop; i += increment) { - auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); - z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); - } - }; + if(shape::elementWiseStride(zShapeInfo) == 1 && shape::elementWiseStride(xShapeInfo) == 1 && shape::order(xShapeInfo) == shape::order(zShapeInfo)){ + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op(x[i], i, length, rng, extraArguments); + } + }; + samediff::Threads::parallel_for(func, 0, length, 1); + } + else{ + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (uint64_t i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, xShapeInfo, xShapeInfoCast, canCastX); + z[offset] = OpClass::op(x[offset], i, length, rng, extraArguments); + } + }; - samediff::Threads::parallel_for(func, 0, length, 1); + samediff::Threads::parallel_for(func, 0, length, 1); + } } else { @@ -208,20 +235,34 @@ namespace functions { auto length = shape::length(zShapeInfo); nd4j::graph::RandomGenerator* rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); - nd4j::OmpLaunchHelper info(length); - uint zShapeInfoCast[MAX_RANK]; - const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + if(shape::elementWiseStride(zShapeInfo) == 1){ - auto func = PRAGMA_THREADS_FOR { - PRAGMA_OMP_SIMD - for (uint64_t i = start; i < stop; i += increment) { - auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); - z[offset] = OpClass::op(i, length, rng, extraArguments); - } - }; + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (auto i = start; i < stop; i++) { + z[i] = OpClass::op( i, length, rng, extraArguments); + } + }; - samediff::Threads::parallel_for(func, 0, length, 1); + samediff::Threads::parallel_for(func, 0, 
length, 1); + } + else{ + nd4j::OmpLaunchHelper info(length); + + uint zShapeInfoCast[MAX_RANK]; + const bool canCastZ = nd4j::DataTypeUtils::castShapeInfo(zShapeInfo, zShapeInfoCast); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (uint64_t i = start; i < stop; i += increment) { + auto offset = shape::indexOffset(i, zShapeInfo, zShapeInfoCast, canCastZ); + z[offset] = OpClass::op(i, length, rng, extraArguments); + } + }; + + samediff::Threads::parallel_for(func, 0, length, 1); + } } template